Diffstat (limited to 'modules'):
-rw-r--r--  modules/dat_model.py        79
-rw-r--r--  modules/postprocessing.py    4
-rw-r--r--  modules/processing.py       12
-rw-r--r--  modules/scripts.py          17
-rw-r--r--  modules/shared_items.py      5
-rw-r--r--  modules/shared_options.py    3
-rw-r--r--  modules/txt2img.py          30
7 files changed, 128 insertions, 22 deletions
diff --git a/modules/dat_model.py b/modules/dat_model.py
new file mode 100644
index 00000000..495d5f49
--- /dev/null
+++ b/modules/dat_model.py
@@ -0,0 +1,79 @@
+import os
+
+from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerDAT(Upscaler):
+    def __init__(self, user_path):
+        self.name = "DAT"
+        self.user_path = user_path
+        self.scalers = []
+        super().__init__()
+
+        for file in self.find_models(ext_filter=[".pt", ".pth"]):
+            name = modelloader.friendly_name(file)
+            scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
+            self.scalers.append(scaler_data)
+
+        for model in get_dat_models(self):
+            if model.name in opts.dat_enabled_models:
+                self.scalers.append(model)
+
+    def do_upscale(self, img, path):
+        try:
+            info = self.load_model(path)
+        except Exception:
+            errors.report(f"Unable to load DAT model {path}", exc_info=True)
+            return img
+
+        model_descriptor = modelloader.load_spandrel_model(
+            info.local_data_path,
+            device=self.device,
+            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+            expected_architecture="DAT",
+        )
+        return upscale_with_model(
+            model_descriptor,
+            img,
+            tile_size=opts.DAT_tile,
+            tile_overlap=opts.DAT_tile_overlap,
+        )
+
+    def load_model(self, path):
+        for scaler in self.scalers:
+            if scaler.data_path == path:
+                if scaler.local_data_path.startswith("http"):
+                    scaler.local_data_path = modelloader.load_file_from_url(
+                        scaler.data_path,
+                        model_dir=self.model_download_path,
+                    )
+                if not os.path.exists(scaler.local_data_path):
+                    raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
+                return scaler
+        raise ValueError(f"Unable to find model info: {path}")
+
+
+def get_dat_models(scaler):
+    return [
+        UpscalerData(
+            name="DAT x2",
+            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
+            scale=2,
+            upscaler=scaler,
+        ),
+        UpscalerData(
+            name="DAT x3",
+            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
+            scale=3,
+            upscaler=scaler,
+        ),
+        UpscalerData(
+            name="DAT x4",
+            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
+            scale=4,
+            upscaler=scaler,
+        ),
+    ]
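
A minimal usage sketch for the new upscaler, for orientation; it relies only on what the file above defines (UpscalerDAT, its scalers list, do_upscale, data_path), while the file names, the "DAT x4" selection, and the user_path=None argument are illustrative assumptions:

# Sketch only: exercise UpscalerDAT directly, assuming the webui options
# (opts.dat_enabled_models, opts.DAT_tile, opts.DAT_tile_overlap) are loaded.
from PIL import Image
from modules.dat_model import UpscalerDAT

upscaler = UpscalerDAT(user_path=None)   # discovers local .pt/.pth models plus the enabled downloadable ones
scaler = next(s for s in upscaler.scalers if s.name == "DAT x4")

img = Image.open("input.png")                          # hypothetical input file
result = upscaler.do_upscale(img, scaler.data_path)    # downloads weights on first use, then upscales tile by tile
result.save("input_x4.png")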
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 7449b0dc..f1488232 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -62,8 +62,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
         else:
             image_data = image_placeholder
 
-        shared.state.assign_current_image(image_data)
-
         parameters, existing_pnginfo = images.read_info_from_image(image_data)
         if parameters:
             existing_pnginfo["parameters"] = parameters
@@ -92,6 +90,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
         pp.image.info = existing_pnginfo
         pp.image.info["postprocessing"] = infotext
 
+        shared.state.assign_current_image(pp.image)
+
         if save_output:
             fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
diff --git a/modules/processing.py b/modules/processing.py
index dcc807fe..6b631795 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1029,6 +1029,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 image = apply_overlay(image, p.paste_to, overlay_image)
 
+                if p.scripts is not None:
+                    pp = scripts.PostprocessImageArgs(image)
+                    p.scripts.postprocess_image_after_composite(p, pp)
+                    image = pp.image
+
                 if save_samples:
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -1227,8 +1232,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if not state.processing_has_refined_job_count:
                 if state.job_count == -1:
                     state.job_count = self.n_iter
-
-                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
+                if getattr(self, 'txt2img_upscale', False):
+                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
+                else:
+                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
+                shared.total_tqdm.updateTotal(total_steps)
                 state.job_count = state.job_count * 2
                 state.processing_has_refined_job_count = True
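
The txt2img_upscale flag set in modules/txt2img.py (further below) feeds this branch: when an existing image is reused as the first pass, only the hires steps should count toward the progress bar. A worked example of the arithmetic, with illustrative values:

# Illustrative values only, to show the total_steps arithmetic above.
steps = 20                # first-pass sampling steps
hr_second_pass_steps = 0  # 0 falls back to `steps` via the `or`
job_count = 1

normal_total = (steps + (hr_second_pass_steps or steps)) * job_count   # 40: both passes are sampled
upscale_total = (hr_second_pass_steps or steps) * job_count            # 20: first pass is skipped, the image is reused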
diff --git a/modules/scripts.py b/modules/scripts.py
index cf938ebb..060069cf 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -262,6 +262,15 @@ class Script:
         pass
 
+    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
+        """
+        Called for every image after it has been generated.
+        Same as postprocess_image, but called after the inpaint_full_res composite,
+        so it operates on the full image instead of the inpaint_full_res crop region.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
         """
         This function is called after processing ends for AlwaysVisible scripts.
@@ -856,6 +865,14 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
+    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_image_after_composite(p, pp, *script_args)
+            except Exception:
+                errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
+
     def before_component(self, component, **kwargs):
         for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
             try:
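
For extension authors, a hypothetical minimal script showing where the new hook fires; pp.image is the full composited picture at this point, not the inpaint_full_res crop. The class layout follows the usual Script conventions (title/show), and the text stamp is only a placeholder:

# Hypothetical example extension (sketch), not part of this change.
from PIL import ImageDraw

from modules import scripts


class ExampleAfterCompositeScript(scripts.Script):
    def title(self):
        return "Example: stamp after composite"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def postprocess_image_after_composite(self, p, pp: scripts.PostprocessImageArgs, *args):
        # pp.image is the full PIL image after the inpaint_full_res composite;
        # modify it in place or assign a replacement to pp.image.
        ImageDraw.Draw(pp.image).text((8, 8), "processed", fill="white")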
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 13fb2814..88f63645 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -8,6 +8,11 @@ def realesrgan_models_names():
     return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
 
 
+def dat_models_names():
+    import modules.dat_model
+    return [x.name for x in modules.dat_model.get_dat_models(None)]
+
+
 def postprocessing_scripts():
     import modules.scripts
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 63488f4e..74a2a67f 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -97,6 +97,9 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+ "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}),
+ "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+ "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
}))
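
These options are read back by modules/dat_model.py (opts.DAT_tile, opts.DAT_tile_overlap) and by shared_items.dat_models_names() for the checkbox choices. A short sketch of the registered defaults as they would be seen at upscale time:

# Sketch: reading the defaults registered above (values before the user changes them).
from modules.shared import opts

enabled = opts.dat_enabled_models     # ["DAT x2", "DAT x3", "DAT x4"] by default
tile = opts.DAT_tile                  # 192 by default; 0 disables tiling
tile_overlap = opts.DAT_tile_overlap  # 8 by default; low values can leave visible seams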
diff --git a/modules/txt2img.py b/modules/txt2img.py
index d22a1f31..4efcb4c3 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -3,7 +3,7 @@ from contextlib import closing
 
 import modules.scripts
 from modules import processing, infotext_utils
-from modules.infotext_utils import create_override_settings_dict
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
 from modules.shared import opts
 import modules.shared as shared
 from modules.ui import plaintext_to_html
@@ -64,19 +64,18 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
     p.enable_hr = True
     p.batch_size = 1
     p.n_iter = 1
+    p.txt2img_upscale = True
 
     geninfo = json.loads(generation_info)
-    all_seeds = geninfo["all_seeds"]
-    all_subseeds = geninfo["all_subseeds"]
 
     image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
     p.firstpass_image = infotext_utils.image_from_url_text(image_info)
 
-    gallery_index_from_end = len(gallery) - gallery_index
-    seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0]
-    subseed = all_subseeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0]
-    p.seed = seed
-    p.subseed = subseed
+    parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
+    p.seed = parameters.get('Seed', -1)
+    p.subseed = parameters.get('Variation seed', -1)
+
+    p.override_settings['save_images_before_highres_fix'] = False
 
     with closing(p):
         processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
@@ -88,18 +87,13 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
     new_gallery = []
     for i, image in enumerate(gallery):
-        fake_image = Image.new(mode="RGB", size=(1, 1))
-
         if i == gallery_index:
-            already_saved_as = getattr(processed.images[0], 'already_saved_as', None)
-            if already_saved_as is not None:
-                fake_image.already_saved_as = already_saved_as
-            else:
-                fake_image = processed.images[0]
+            geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
+            new_gallery.extend(processed.images)
         else:
-            fake_image.already_saved_as = image["name"]
-
-        new_gallery.append(fake_image)
+            fake_image = Image.new(mode="RGB", size=(1, 1))
+            fake_image.already_saved_as = image["name"].rsplit('?', 1)[0]
+            new_gallery.append(fake_image)
 
     geninfo["infotexts"][gallery_index] = processed.info
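
The seed recovery above now goes through the stored infotext rather than the all_seeds/all_subseeds arrays. Roughly, parse_generation_parameters turns the parameter tail of an infotext into a dict keyed by the labels used in the UI; a sketch with a made-up infotext string:

# Sketch: the lookup performed in txt2img_upscale above, on an invented infotext.
from modules.infotext_utils import parse_generation_parameters

infotext = (
    "a photo of a cat\n"
    "Negative prompt: blurry\n"
    "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 12345, Variation seed: 67890, Size: 512x512"
)

parameters = parse_generation_parameters(infotext, [])
seed = parameters.get('Seed', -1)                # "12345" (values are parsed as strings)
subseed = parameters.get('Variation seed', -1)   # "67890"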