aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorArturo Albacete <aalbacetef@gmail.com>2024-01-20 21:15:57 +0100
committerArturo Albacete <aalbacetef@gmail.com>2024-01-20 21:15:57 +0100
commitd0b65e148bdc3d35f3f8ee38310ba55152ab4880 (patch)
tree1485c4f2df143b4c3d7e958a232dc9ddb2205280 /modules
parent315e40a49c32438551ed6b66138acdf664ecdbc8 (diff)
parentf939bce845ae07536b1c920618743af83e0b01ec (diff)
merge dev
Diffstat (limited to 'modules')
-rw-r--r--modules/dat_model.py79
-rw-r--r--modules/devices.py11
-rw-r--r--modules/extensions.py5
-rw-r--r--modules/images.py2
-rw-r--r--modules/infotext_utils.py8
-rw-r--r--modules/postprocessing.py4
-rw-r--r--modules/processing.py12
-rw-r--r--modules/processing_scripts/seed.py19
-rw-r--r--modules/scripts.py17
-rw-r--r--modules/shared_items.py5
-rw-r--r--modules/shared_options.py3
-rw-r--r--modules/txt2img.py30
-rw-r--r--modules/ui.py4
-rw-r--r--modules/ui_common.py19
-rw-r--r--modules/ui_postprocessing.py2
-rw-r--r--modules/ui_toprow.py3
16 files changed, 168 insertions, 55 deletions
diff --git a/modules/dat_model.py b/modules/dat_model.py
new file mode 100644
index 00000000..495d5f49
--- /dev/null
+++ b/modules/dat_model.py
@@ -0,0 +1,79 @@
+import os
+
+from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerDAT(Upscaler):
+ def __init__(self, user_path):
+ self.name = "DAT"
+ self.user_path = user_path
+ self.scalers = []
+ super().__init__()
+
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
+ name = modelloader.friendly_name(file)
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
+ self.scalers.append(scaler_data)
+
+ for model in get_dat_models(self):
+ if model.name in opts.dat_enabled_models:
+ self.scalers.append(model)
+
+ def do_upscale(self, img, path):
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load DAT model {path}", exc_info=True)
+ return img
+
+ model_descriptor = modelloader.load_spandrel_model(
+ info.local_data_path,
+ device=self.device,
+ prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="DAT",
+ )
+ return upscale_with_model(
+ model_descriptor,
+ img,
+ tile_size=opts.DAT_tile,
+ tile_overlap=opts.DAT_tile_overlap,
+ )
+
+ def load_model(self, path):
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
+
+
+def get_dat_models(scaler):
+ return [
+ UpscalerData(
+ name="DAT x2",
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
+ scale=2,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="DAT x3",
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
+ scale=3,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="DAT x4",
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ ]
diff --git a/modules/devices.py b/modules/devices.py
index 0321d12c..dfffaf24 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -164,7 +164,11 @@ def manual_cast_forward(target_dtype):
@contextlib.contextmanager
def manual_cast(target_dtype):
+ applied = False
for module_type in patch_module_list:
+ if hasattr(module_type, "org_forward"):
+ continue
+ applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention and has_xpu():
module_type.forward = manual_cast_forward(torch.float32)
@@ -174,8 +178,11 @@ def manual_cast(target_dtype):
try:
yield None
finally:
- for module_type in patch_module_list:
- module_type.forward = module_type.org_forward
+ if applied:
+ for module_type in patch_module_list:
+ if hasattr(module_type, "org_forward"):
+ module_type.forward = module_type.org_forward
+ delattr(module_type, "org_forward")
def autocast(disable=False):
diff --git a/modules/extensions.py b/modules/extensions.py
index 99e7ee60..04bda297 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -224,13 +224,16 @@ def list_extensions():
# check for requirements
for extension in extensions:
+ if not extension.enabled:
+ continue
+
for req in extension.metadata.requires:
required_extension = loaded_extensions.get(req)
if required_extension is None:
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
- if not extension.enabled:
+ if not required_extension.enabled:
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
continue
diff --git a/modules/images.py b/modules/images.py
index 87a7bf22..b6f2358c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -321,7 +321,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res
-invalid_filename_chars = '<>:"/\\|?*\n\r\t'
+invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py
index 9a02cdf2..1049c6c3 100644
--- a/modules/infotext_utils.py
+++ b/modules/infotext_utils.py
@@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res):
res['Hires resize-2'] = height
-def parse_generation_parameters(x: str):
+def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
@@ -240,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
returns a dict with field values
"""
+ if skip_fields is None:
+ skip_fields = shared.opts.infotext_skip_pasting
res = {}
@@ -356,8 +358,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
infotext_versions.backcompat(res)
- skip = set(shared.opts.infotext_skip_pasting)
- res = {k: v for k, v in res.items() if k not in skip}
+ for key in skip_fields:
+ res.pop(key, None)
return res
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 7449b0dc..f1488232 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -62,8 +62,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
else:
image_data = image_placeholder
- shared.state.assign_current_image(image_data)
-
parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters:
existing_pnginfo["parameters"] = parameters
@@ -92,6 +90,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext
+ shared.state.assign_current_image(pp.image)
+
if save_output:
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
diff --git a/modules/processing.py b/modules/processing.py
index dcc807fe..6b631795 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1029,6 +1029,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image = apply_overlay(image, p.paste_to, overlay_image)
+ if p.scripts is not None:
+ pp = scripts.PostprocessImageArgs(image)
+ p.scripts.postprocess_image_after_composite(p, pp)
+ image = pp.image
+
if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -1227,8 +1232,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not state.processing_has_refined_job_count:
if state.job_count == -1:
state.job_count = self.n_iter
-
- shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
+ if getattr(self, 'txt2img_upscale', False):
+ total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
+ else:
+ total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
+ shared.total_tqdm.updateTotal(total_steps)
state.job_count = state.job_count * 2
state.processing_has_refined_job_count = True
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index 2d3cbb97..7a4c0159 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -6,6 +6,7 @@ from modules import scripts, ui, errors
from modules.infotext_utils import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
+from modules import infotext_utils
class ScriptSeed(scripts.ScriptBuiltinUI):
@@ -77,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
p.seed_resize_from_h = seed_resize_from_h
-
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -85,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
def copy_seed(gen_info_string: str, index):
res = -1
-
try:
gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError:
+ infotext = gen_info.get('infotexts')[index]
+ gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
+ res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
+ except Exception:
if gen_info_string:
- errors.report(f"Error parsing JSON generation info: {gen_info_string}")
+ errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)
return [res, gr.update()]
diff --git a/modules/scripts.py b/modules/scripts.py
index cf938ebb..060069cf 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -262,6 +262,15 @@ class Script:
pass
+ def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
+ """
+ Called for every image after it has been generated.
+ Same as postprocess_image but after inpaint_full_res composite
+ So that it operates on the full image instead of the inpaint_full_res crop region.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -856,6 +865,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_image_after_composite(p, pp, *script_args)
+ except Exception:
+ errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
+
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 13fb2814..88f63645 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -8,6 +8,11 @@ def realesrgan_models_names():
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
+def dat_models_names():
+ import modules.dat_model
+ return [x.name for x in modules.dat_model.get_dat_models(None)]
+
+
def postprocessing_scripts():
import modules.scripts
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 63488f4e..74a2a67f 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -97,6 +97,9 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+ "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}),
+ "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+ "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
}))
diff --git a/modules/txt2img.py b/modules/txt2img.py
index d22a1f31..4efcb4c3 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -3,7 +3,7 @@ from contextlib import closing
import modules.scripts
from modules import processing, infotext_utils
-from modules.infotext_utils import create_override_settings_dict
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
@@ -64,19 +64,18 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
p.enable_hr = True
p.batch_size = 1
p.n_iter = 1
+ p.txt2img_upscale = True
geninfo = json.loads(generation_info)
- all_seeds = geninfo["all_seeds"]
- all_subseeds = geninfo["all_subseeds"]
image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
p.firstpass_image = infotext_utils.image_from_url_text(image_info)
- gallery_index_from_end = len(gallery) - gallery_index
- seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0]
- subseed = all_subseeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0]
- p.seed = seed
- p.subseed = subseed
+ parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
+ p.seed = parameters.get('Seed', -1)
+ p.subseed = parameters.get('Variation seed', -1)
+
+ p.override_settings['save_images_before_highres_fix'] = False
with closing(p):
processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
@@ -88,18 +87,13 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
new_gallery = []
for i, image in enumerate(gallery):
- fake_image = Image.new(mode="RGB", size=(1, 1))
-
if i == gallery_index:
- already_saved_as = getattr(processed.images[0], 'already_saved_as', None)
- if already_saved_as is not None:
- fake_image.already_saved_as = already_saved_as
- else:
- fake_image = processed.images[0]
+ geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
+ new_gallery.extend(processed.images)
else:
- fake_image.already_saved_as = image["name"]
-
- new_gallery.append(fake_image)
+ fake_image = Image.new(mode="RGB", size=(1, 1))
+ fake_image.already_saved_as = image["name"].rsplit('?', 1)[0]
+ new_gallery.append(fake_image)
geninfo["infotexts"][gallery_index] = processed.info
diff --git a/modules/ui.py b/modules/ui.py
index a716a040..ebd33a85 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -532,7 +532,7 @@ def create_ui():
if category == "image":
with gr.Tabs(elem_id="mode_img2img"):
- img2img_selected_tab = gr.State(0)
+ img2img_selected_tab = gr.Number(value=0, visible=False)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
@@ -613,7 +613,7 @@ def create_ui():
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
- selected_scale_tab = gr.State(value=0)
+ selected_scale_tab = gr.Number(value=0, visible=False)
with gr.Tabs():
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 1db72092..92b00719 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -63,8 +63,9 @@ def save_files(js_data, images, do_make_zip, index):
import csv
filenames = []
fullfns = []
+ parsed_infotexts = []
- #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+ # quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
class MyObject:
def __init__(self, d=None):
if d is not None:
@@ -72,16 +73,14 @@ def save_files(js_data, images, do_make_zip, index):
setattr(self, key, value)
data = json.loads(js_data)
-
p = MyObject(data)
+
path = shared.opts.outdir_save
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
- only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
- only_one = True
images = [images[index]]
start_index = index
@@ -117,10 +116,12 @@ def save_files(js_data, images, do_make_zip, index):
image = image_from_url_text(filedata)
is_grid = image_index < p.index_of_first_image
- i = 0 if is_grid else (image_index - p.index_of_first_image)
p.batch_index = image_index-1
- fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+
+ parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], [])
+ parsed_infotexts.append(parameters)
+ fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
filenames.append(filename)
@@ -129,12 +130,12 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"], data["sd_model_name"], data["sd_model_hash"]])
+ writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']])
# Make Zip
if do_make_zip:
- zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
- namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts]
+ namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True)
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index 7a132ac2..ff22a178 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -5,7 +5,7 @@ import modules.infotext_utils as parameters_copypaste
def create_ui():
dummy_component = gr.Label(visible=False)
- tab_index = gr.State(value=0)
+ tab_index = gr.Number(value=0, visible=False)
with gr.Row(equal_height=False, variant='compact'):
with gr.Column(variant='compact'):
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
index 1abc9117..fbe705be 100644
--- a/modules/ui_toprow.py
+++ b/modules/ui_toprow.py
@@ -107,8 +107,9 @@ class Toprow:
)
def interrupt_function():
- if shared.state.job_count > 1 and shared.opts.interrupt_after_current:
+ if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current:
shared.state.stop_generating()
+ gr.Info("Generation will stop after finishing this image, click again to stop immediately.")
else:
shared.state.interrupt()