aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--javascript/hints.js1
-rw-r--r--modules/bsrgan_model.py6
-rw-r--r--modules/codeformer_model.py12
-rw-r--r--modules/devices.py15
-rw-r--r--modules/esrgan_model.py9
-rw-r--r--modules/gfpgan_model.py22
-rw-r--r--modules/images.py39
-rw-r--r--modules/img2img.py7
-rw-r--r--modules/processing.py18
-rw-r--r--modules/prompt_parser.py119
-rw-r--r--modules/scunet_model.py8
-rw-r--r--modules/shared.py13
-rw-r--r--modules/textual_inversion/dataset.py7
-rw-r--r--modules/textual_inversion/preprocess.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py2
-rw-r--r--modules/ui.py14
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
-rw-r--r--scripts/xy_grid.py42
19 files changed, 209 insertions, 131 deletions
diff --git a/javascript/hints.js b/javascript/hints.js
index e72e9338..8adcd983 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -47,6 +47,7 @@ titles = {
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
+ "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
"Tiling": "Produce an image that can be tiled.",
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
index e62c6657..3bd80791 100644
--- a/modules/bsrgan_model.py
+++ b/modules/bsrgan_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
@@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_bsrgan)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index a29f3855..e6d9fa4f 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net
self.face_helper = face_helper
- self.net.to(devices.device_codeformer)
return net, face_helper
+ def send_model_to(self, device):
+ self.net.to(device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1]
@@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None:
return np_image
+ self.send_model_to(devices.device_codeformer)
+
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -113,8 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ self.face_helper.clean_all()
+
if shared.opts.face_restoration_unload:
- self.net.to(devices.cpu)
+ self.send_model_to(devices.cpu)
return restored_img
diff --git a/modules/devices.py b/modules/devices.py
index ff82f2f6..0158b11f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,8 +1,10 @@
+import contextlib
+
import torch
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors
+# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
@@ -32,8 +34,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = get_optimal_device()
-device_codeformer = cpu if has_mps else device
+device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
def randn(seed, shape):
@@ -57,3 +58,11 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
+
+def autocast():
+ from modules import shared
+
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ return contextlib.nullcontext()
+
+ return torch.autocast("cuda")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 4aed9283..d17e730f 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -6,8 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-from modules import shared, modelloader, images
-from modules.devices import has_mps
+from modules import shared, modelloader, images, devices
from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
@@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
@@ -127,7 +126,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index bb30d733..a9452dce 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(shared.device)
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model
if gfpgan_constructor is None:
@@ -37,22 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
- model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
return model
+def send_model_to(model, device):
+ model.gfpgan.to(device)
+ model.face_helper.face_det.to(device)
+ model.face_helper.face_parse.to(device)
+
+
def gfpgan_fix_faces(np_image):
model = gfpgann()
if model is None:
return np_image
+
+ send_model_to(model, devices.device_gfpgan)
+
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
+ model.face_helper.clean_all()
+
if shared.opts.face_restoration_unload:
- model.gfpgan.to(devices.cpu)
+ send_model_to(model, devices.cpu)
return np_image
@@ -97,11 +107,7 @@ def setup_model(dirname):
return "GFPGAN"
def restore(self, np_image):
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- return np_image
+ return gfpgan_fix_faces(np_image)
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
diff --git a/modules/images.py b/modules/images.py
index 1a046aca..c2fadab9 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -287,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt):
if seed is not None:
x = x.replace("[seed]", str(seed))
+ if p is not None:
+ x = x.replace("[steps]", str(p.steps))
+ x = x.replace("[cfg]", str(p.cfg_scale))
+ x = x.replace("[width]", str(p.width))
+ x = x.replace("[height]", str(p.height))
+
+ #currently disabled if using the save button, will work otherwise
+ # if enabled it will cause a bug because styles is not included in the save_files data dictionary
+ if hasattr(p, "styles"):
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
+
+ x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+
+ x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
+ x = x.replace("[date]", datetime.date.today().isoformat())
+ x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
+ x = x.replace("[job_timestamp]", shared.state.job_timestamp)
+
+ # Apply [prompt] at last. Because it may contain any replacement word.^M
if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x:
@@ -295,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")]
for part in style_parts:
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
@@ -306,24 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
- if p is not None:
- x = x.replace("[steps]", str(p.steps))
- x = x.replace("[cfg]", str(p.cfg_scale))
- x = x.replace("[width]", str(p.width))
- x = x.replace("[height]", str(p.height))
-
- #currently disabled if using the save button, will work otherwise
- # if enabled it will cause a bug because styles is not included in the save_files data dictionary
- if hasattr(p, "styles"):
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
-
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
-
- x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
- x = x.replace("[date]", datetime.date.today().isoformat())
- x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
- x = x.replace("[job_timestamp]", shared.state.job_timestamp)
-
if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
diff --git a/modules/img2img.py b/modules/img2img.py
index f4455c90..2ff8e261 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+ save_normally = output_dir == ''
+
p.do_not_save_grid = True
- p.do_not_save_samples = True
+ p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter
@@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}"
- processed_image.save(os.path.join(output_dir, filename))
+ if not save_normally:
+ processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
diff --git a/modules/processing.py b/modules/processing.py
index 0a4b6198..6f5599c7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,3 @@
-import contextlib
import json
import math
import os
@@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
- precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
- ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
- with torch.no_grad(), precision_scope("cuda"), ema_scope():
+
+ with torch.no_grad():
p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1:
@@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
- uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
- c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+ with devices.autocast():
+ uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
+ c = prompt_parser.get_learned_conditioning(prompts, p.steps)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -361,13 +360,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ with devices.autocast():
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+
if state.interrupted:
# if we are interruped, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
+ samples_ddim = samples_ddim.to(devices.dtype)
+
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -386,6 +389,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample)
+ devices.torch_gc()
image = Image.fromarray(x_sample)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index e811eb9e..99c8ed99 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,20 +1,11 @@
import re
from collections import namedtuple
import torch
+from lark import Lark, Transformer, Visitor
+import functools
import modules.shared as shared
-re_prompt = re.compile(r'''
-(.*?)
-\[
- ([^]:]+):
- (?:([^]:]*):)?
- ([0-9]*\.?[0-9]+)
-]
-|
-(.+)
-''', re.X)
-
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@@ -25,61 +16,57 @@ re_prompt = re.compile(r'''
def get_learned_conditioning_prompt_schedules(prompts, steps):
- res = []
- cache = {}
-
- for prompt in prompts:
- prompt_schedule: list[list[str | int]] = [[steps, ""]]
-
- cached = cache.get(prompt, None)
- if cached is not None:
- res.append(cached)
- continue
-
- for m in re_prompt.finditer(prompt):
- plaintext = m.group(1) if m.group(5) is None else m.group(5)
- concept_from = m.group(2)
- concept_to = m.group(3)
- if concept_to is None:
- concept_to = concept_from
- concept_from = ""
- swap_position = float(m.group(4)) if m.group(4) is not None else None
-
- if swap_position is not None:
- if swap_position < 1:
- swap_position = swap_position * steps
- swap_position = int(min(swap_position, steps))
-
- swap_index = None
- found_exact_index = False
- for i in range(len(prompt_schedule)):
- end_step = prompt_schedule[i][0]
- prompt_schedule[i][1] += plaintext
-
- if swap_position is not None and swap_index is None:
- if swap_position == end_step:
- swap_index = i
- found_exact_index = True
-
- if swap_position < end_step:
- swap_index = i
-
- if swap_index is not None:
- if not found_exact_index:
- prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
-
- for i in range(len(prompt_schedule)):
- end_step = prompt_schedule[i][0]
- must_replace = swap_position < end_step
-
- prompt_schedule[i][1] += concept_to if must_replace else concept_from
-
- res.append(prompt_schedule)
- cache[prompt] = prompt_schedule
- #for t in prompt_schedule:
- # print(t)
-
- return res
+ grammar = r"""
+ start: prompt
+ prompt: (emphasized | scheduled | weighted | plain)*
+ !emphasized: "(" prompt ")"
+ | "(" prompt ":" prompt ")"
+ | "[" prompt "]"
+ scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
+ !weighted: "{" weighted_item ("|" weighted_item)* "}"
+ !weighted_item: prompt (":" prompt)?
+ plain: /([^\\\[\](){}:|]|\\.)+/
+ %import common.SIGNED_NUMBER -> NUMBER
+ """
+ parser = Lark(grammar, parser='lalr')
+ def collect_steps(steps, tree):
+ l = [steps]
+ class CollectSteps(Visitor):
+ def scheduled(self, tree):
+ tree.children[-1] = float(tree.children[-1])
+ if tree.children[-1] < 1:
+ tree.children[-1] *= steps
+ tree.children[-1] = min(steps, int(tree.children[-1]))
+ l.append(tree.children[-1])
+ CollectSteps().visit(tree)
+ return sorted(set(l))
+ def at_step(step, tree):
+ class AtStep(Transformer):
+ def scheduled(self, args):
+ if len(args) == 2:
+ before, after, when = (), *args
+ else:
+ before, after, when = args
+ yield before if step <= when else after
+ def start(self, args):
+ def flatten(x):
+ if type(x) == str:
+ yield x
+ else:
+ for gen in x:
+ yield from flatten(gen)
+ return ''.join(flatten(args[0]))
+ def plain(self, args):
+ yield args[0].value
+ def __default__(self, data, children, meta):
+ for child in children:
+ yield from child
+ return AtStep().transform(tree)
+ @functools.cache
+ def get_schedule(prompt):
+ tree = parser.parse(prompt)
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
+ return [get_schedule(prompt) for prompt in prompts]
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 7987ac14..fb64b740 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net
@@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None:
return img
- device = shared.device
+ device = devices.device_scunet
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(device)
img = img.to(device)
with torch.no_grad():
@@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
- device = shared.device
+ device = devices.device_scunet
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
diff --git a/modules/shared.py b/modules/shared.py
index 2a599e9c..a7d13b2d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,7 +12,7 @@ import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
-from modules.devices import get_optimal_device
+import modules.devices as devices
from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -46,6 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -63,7 +64,11 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print
cmd_opts = parser.parse_args()
-device = get_optimal_device()
+
+devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+
+device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
@@ -195,7 +200,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
- "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
+ "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
@@ -204,7 +209,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index e8394ff6..7c44ea5b 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -9,6 +9,9 @@ from torchvision import transforms
import random
import tqdm
from modules import devices
+import re
+
+re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
@@ -38,8 +41,8 @@ class PersonalizedBase(Dataset):
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
- filename_tokens = [token for token in filename_tokens if token.isalpha()]
+ filename_tokens = os.path.splitext(filename)[0]
+ filename_tokens = re_tag.findall(filename_tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 209e928f..f545a993 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
if process_caption:
caption = "-" + shared.interrogator.generate_caption(image)
else:
- caption = ""
+ caption = filename
+ caption = os.path.splitext(caption)[0]
+ caption = os.path.basename(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8686f534..cd9f3498 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
diff --git a/modules/ui.py b/modules/ui.py
index 16432151..20dc8c37 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -386,14 +386,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component]
)
+
def update_token_counter(text, steps):
- prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+ try:
+ prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step,prompt_text in flat_prompts]
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
+
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
@@ -658,7 +666,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>")
+ gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
diff --git a/requirements.txt b/requirements.txt
index d4b337fc..631fe616 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,3 +22,4 @@ clean-fid
resize-right
torchdiffeq
kornia
+lark
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 8a9acf20..fdff2687 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -21,3 +21,4 @@ clean-fid==0.1.29
resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
+lark==1.1.2
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 146663b0..1237e754 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -1,5 +1,6 @@
from collections import namedtuple
from copy import copy
+from itertools import permutations
import random
from PIL import Image
@@ -29,6 +30,31 @@ def apply_prompt(p, x, xs):
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+def apply_order(p, x, xs):
+ token_order = []
+
+ # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ prompt_parts = []
+
+ # Split the prompt up, taking out the tokens
+ for _, token in token_order:
+ n = p.prompt.find(token)
+ prompt_parts.append(p.prompt[0:n])
+ p.prompt = p.prompt[n + len(token):]
+
+ # Rebuild the prompt with the tokens in the order we want
+ prompt_tmp = ""
+ for idx, part in enumerate(prompt_parts):
+ prompt_tmp += part
+ prompt_tmp += x[idx]
+ p.prompt = prompt_tmp + p.prompt
+
+
samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
samplers_dict[sampler.name.lower()] = i
@@ -60,16 +86,26 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
-
return x
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
def do_nothing(p, x, xs):
pass
+
def format_nothing(p, opt, x):
return ""
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
@@ -82,6 +118,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
@@ -131,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+
class Script(scripts.Script):
def title(self):
return "X/Y plot"
@@ -206,6 +244,8 @@ class Script(scripts.Script):
valslist_ext.append(val)
valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]