From 03ee67bfd34b9e872b33eb05fef5db83410b16f3 Mon Sep 17 00:00:00 2001 From: WDevelopsWebApps <97454358+WDevelopsWebApps@users.noreply.github.com> Date: Wed, 28 Sep 2022 10:53:40 +0200 Subject: add advanced saving for save button --- modules/images.py | 5 ++++- modules/ui.py | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 9458bf8d..923f81df 100644 --- a/modules/images.py +++ b/modules/images.py @@ -290,7 +290,10 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[cfg]", str(p.cfg_scale)) x = x.replace("[width]", str(p.width)) x = x.replace("[height]", str(p.height)) - x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) + #currently disabled if using the save button, will work otherwise + # if enabled it will cause a bug because styles is not included in the save_files data dictionary + if hasattr(p, "styles"): + x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) diff --git a/modules/ui.py b/modules/ui.py index 7db8edbd..87a86a45 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -28,6 +28,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +from modules.images import apply_filename_pattern, get_next_sequence_number # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -90,13 +91,26 @@ def send_gradio_gallery_to_image(x): def save_files(js_data, images, index): - import csv - - os.makedirs(opts.outdir_save, exist_ok=True) - + import csv filenames = [] + #quick dictionary to class object conversion.
It's necessary because apply_filename_pattern requires it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + data = json.loads(js_data) + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.save_to_dirs + + if save_to_dirs: + dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt) + path = os.path.join(opts.outdir_save, dirname) + + os.makedirs(path, exist_ok=True) if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only images = [images[index]] @@ -107,11 +121,18 @@ def save_files(js_data, images, index): writer = csv.writer(file) if at_start: writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - filename_base = str(int(time.time() * 1000)) + file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" + if file_decoration != "": + file_decoration = "-" + file_decoration.lower() + file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt) + truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration + filename_base = truncated + + basecount = get_next_sequence_number(path, "") for i, filedata in enumerate(images): - filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png" - filepath = os.path.join(opts.outdir_save, filename) + file_number = f"{basecount+i:05}" + filename = file_number + filename_base + ".png" + filepath = os.path.join(path, filename) if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] -- cgit v1.2.1 From e82ea202997cbcd2ab72891cd075d9ba270eb67d Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Fri, 30 Sep 2022 15:26:18 -0500 Subject: Optimize model loader Child classes only get populated to __subclasses__ when they are imported. We don't actually need to import any of them to webui any more, so clean up webui imports and make sure the loader imports children. Also, fix command line paths not actually being passed to the scalers.
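A minimal, hypothetical sketch of the discovery pattern this commit describes (the Upscaler stand-in and the modules/ directory layout here are assumptions for illustration, not code from the patch):

import importlib
import os

class Upscaler:  # stand-in for modules.upscaler.Upscaler
    pass

def discover_upscalers(modules_dir="modules"):
    # A subclass only appears in Upscaler.__subclasses__() once the module
    # defining it has actually been executed, so import every *_model.py
    # file first, then read the registry.
    for file in os.listdir(modules_dir):
        if file.endswith("_model.py"):
            try:
                importlib.import_module(f"modules.{file[:-3]}")
            except ImportError:
                pass  # a scaler with missing dependencies is simply skipped
    return Upscaler.__subclasses__()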
--- modules/modelloader.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index 1106aeb7..b1721671 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -4,7 +4,6 @@ import importlib from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url - from modules import shared from modules.upscaler import Upscaler from modules.paths import script_path, models_path @@ -120,16 +119,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): def load_upscalers(): + sd = shared.script_path + # We can only do this 'magic' method to dynamically load upscalers if they are referenced, + # so we'll try to import any _model.py files before looking in __subclasses__ + modules_dir = os.path.join(sd, "modules") + for file in os.listdir(modules_dir): + if "_model.py" in file: + model_name = file.replace("_model.py", "") + full_model = f"modules.{model_name}_model" + try: + importlib.import_module(full_model) + except: + pass datas = [] + c_o = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): name = cls.__name__ module_name = cls.__module__ module = importlib.import_module(module_name) class_ = getattr(module, name) - cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" opt_string = None try: - opt_string = shared.opts.__getattr__(cmd_name) + if cmd_name in c_o: + opt_string = c_o[cmd_name] except: pass scaler = class_(opt_string) -- cgit v1.2.1 From 8deae077004f0332ca607fc3a5d568b1a4705bec Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Fri, 30 Sep 2022 15:28:37 -0500 Subject: Add ScuNET DeNoiser/Upscaler Q&D Implementation of ScuNET, thanks to our handy model loader. 
:P https://github.com/cszn/SCUNet --- modules/scunet_model.py | 90 +++++++++++++++ modules/scunet_model_arch.py | 265 +++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 356 insertions(+) create mode 100644 modules/scunet_model.py create mode 100644 modules/scunet_model_arch.py (limited to 'modules') diff --git a/modules/scunet_model.py b/modules/scunet_model.py new file mode 100644 index 00000000..7987ac14 --- /dev/null +++ b/modules/scunet_model.py @@ -0,0 +1,90 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.paths import models_path +from modules.scunet_model_arch import SCUNet as net + + +class UpscalerScuNET(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "ScuNET" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "ScuNET GAN" + self.model_name2 = "ScuNET PSNR" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" + self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pth"]) + scalers = [] + add_model2 = True + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + if name == self.model_name2 or file == self.model_url2: + add_model2 = False + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading ScuNET model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + if add_model2: + scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) + scalers.append(scaler_data2) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + + model = self.load_model(selected_file) + if model is None: + return img + + device = shared.device + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + + img = img.to(device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. 
* np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + device = shared.device + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: + print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) + return None + + model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + + return model + diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py new file mode 100644 index 00000000..972a2639 --- /dev/null +++ b/modules/scunet_model_arch.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_, DropPath + + +class WMSA(nn.Module): + """ Self-attention module in Swin Transformer + """ + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim ** -0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) + + self.relative_position_params = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, std=.02) + self.relative_position_params = torch.nn.Parameter( + self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, 2).transpose( 0, 1)) + + def generate_mask(self, h, w, p, shift): + """ generating the mask of SW-MSA + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: should be (1 1 w p p), + """ + # supporting square. + attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) + if self.type == 'W': + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') + return attn_mask + + def forward(self, x): + """ Forward pass of Window Multi-head Self-attention module.
+ Args: + x: input tensor with shape of [b h w c]; + attn_mask: attention mask, fill -inf where the value is True; + Returns: + output: tensor shape [b h w c] + """ + if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) + h_windows = x.size(1) + w_windows = x.size(2) + # square validation + # assert h_windows == w_windows + + x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) + qkv = self.embedding_layer(x) + q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) + sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale + # Adding learnable relative embedding + sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') + # Using Attn Mask to distinguish different subwindows. + if self.type != 'W': + attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) + sim = sim.masked_fill_(attn_mask, float("-inf")) + + probs = nn.functional.softmax(sim, dim=-1) + output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) + output = rearrange(output, 'h b w p c -> b w p (h c)') + output = self.linear(output) + output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) + + if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), + dims=(1, 2)) + return output + + def relative_embedding(self): + cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) + relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 + # negative is allowed + return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] + + +class Block(nn.Module): + def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer Block + """ + super(Block, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + assert type in ['W', 'SW'] + self.type = type + if input_resolution <= window_size: + self.type = 'W' + + self.ln1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) + self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() + self.ln2 = nn.LayerNorm(input_dim) + self.mlp = nn.Sequential( + nn.Linear(input_dim, 4 * input_dim), + nn.GELU(), + nn.Linear(4 * input_dim, output_dim), + ) + + def forward(self, x): + x = x + self.drop_path(self.msa(self.ln1(x))) + x = x + self.drop_path(self.mlp(self.ln2(x))) + return x + + +class ConvTransBlock(nn.Module): + def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer and Conv Block + """ + super(ConvTransBlock, self).__init__() + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.head_dim = head_dim + self.window_size = window_size + self.drop_path = drop_path + self.type = type + self.input_resolution = input_resolution + + assert self.type in ['W', 'SW'] + if self.input_resolution <= self.window_size: + self.type = 'W' + + self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, + self.type, self.input_resolution) + self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + + self.conv_block = nn.Sequential( + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) + ) + + def forward(self, x): + conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) + conv_x = self.conv_block(conv_x) + conv_x + trans_x = Rearrange('b c h w -> b h w c')(trans_x) + trans_x = self.trans_block(trans_x) + trans_x = Rearrange('b h w c -> b c h w')(trans_x) + res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) + x = x + res + + return x + + +class SCUNet(nn.Module): + # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): + def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): + super(SCUNet, self).__init__() + if config is None: + config = [2, 2, 2, 2, 2, 2, 2] + self.config = config + self.dim = dim + self.head_dim = 32 + self.window_size = 8 + + # drop path rate for each layer + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] + + self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] + + begin = 0 + self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[0])] + \ + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] + + begin += config[0] + self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[1])] + \ + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] + + begin += config[1] + self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 4) + for i in range(config[2])] + \ + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] + + begin += config[2] + self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 8) + for i in range(config[3])] + + begin += config[3] + self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 
else 'SW', input_resolution // 4) + for i in range(config[4])] + + begin += config[4] + self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[5])] + + begin += config[5] + self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[6])] + + self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] + + self.m_head = nn.Sequential(*self.m_head) + self.m_down1 = nn.Sequential(*self.m_down1) + self.m_down2 = nn.Sequential(*self.m_down2) + self.m_down3 = nn.Sequential(*self.m_down3) + self.m_body = nn.Sequential(*self.m_body) + self.m_up3 = nn.Sequential(*self.m_up3) + self.m_up2 = nn.Sequential(*self.m_up2) + self.m_up1 = nn.Sequential(*self.m_up1) + self.m_tail = nn.Sequential(*self.m_tail) + # self.apply(self._init_weights) + + def forward(self, x0): + + h, w = x0.size()[-2:] + paddingBottom = int(np.ceil(h / 64) * 64 - h) + paddingRight = int(np.ceil(w / 64) * 64 - w) + x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) + + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[..., :h, :w] + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 8428c7a3..a48b995a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,6 +40,7 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) +parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. 
By default, it's on for torch.cuda and off for other torch devices.") -- cgit v1.2.1 From abdbf1de646f007b6d76cfb3f416fdfaadb57903 Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 29 Sep 2022 14:40:47 -0400 Subject: token counters now update when roll artist and style buttons are pressed https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/1194#issuecomment-1261203893 --- modules/ui.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..5eea1860 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -539,6 +539,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, + _js="roll_artist_txt2img", inputs=[ txt2img_prompt, ], @@ -743,6 +744,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, + _js="roll_artist_img2img", inputs=[ img2img_prompt, ], @@ -753,6 +755,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_style_txt2img", "update_style_img2img"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): button.click( @@ -764,9 +767,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], ) - for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns): + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): button.click( fn=apply_styles, + _js=js_func, inputs=[prompt, negative_prompt, style1, style2], outputs=[prompt, negative_prompt, style1, style2], ) -- cgit v1.2.1 From ff8dc1908af088d0ed43fb85baad662733c5ca9c Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 29 Sep 2022 15:47:06 -0400 Subject: fixed token counter for prompt editing --- modules/ui.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 5eea1860..6bf28562 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,6 +11,7 @@ import time import traceback import platform import subprocess as sp +from functools import reduce import numpy as np import torch @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +from modules.prompt_parser import get_learned_conditioning_prompt_schedules # this is a fix for Windows users. 
Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -345,8 +347,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) -def update_token_counter(text): - tokens, token_count, max_length = model_hijack.tokenize(text) +def update_token_counter(text, steps): + prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step,prompt_text in flat_prompts] + tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" @@ -364,8 +369,7 @@ def create_toprow(is_img2img): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter]) + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -396,7 +400,7 @@ def create_toprow(is_img2img): prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") save_style = gr.Button('Create style', elem_id="style_create") - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part): @@ -419,7 +423,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) with gr.Row(elem_id='txt2img_progress_row'): @@ -568,9 +572,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, 
img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): with gr.Column(scale=1): @@ -793,6 +798,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (denoising_strength, "Denoising strength"), ] modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): -- cgit v1.2.1 From 3c6a049fc3c6b54ada3736710a7e86663ea7f3d9 Mon Sep 17 00:00:00 2001 From: Liam Date: Fri, 30 Sep 2022 12:12:44 -0400 Subject: consolidated token counter functions --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6bf28562..40c08984 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -543,7 +543,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, - _js="roll_artist_txt2img", + _js="update_txt2img_tokens", inputs=[ txt2img_prompt, ], @@ -749,7 +749,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, - _js="roll_artist_img2img", + _js="update_img2img_tokens", inputs=[ img2img_prompt, ], @@ -760,7 +760,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_style_txt2img", "update_style_img2img"] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): button.click( -- cgit v1.2.1 From 9de1e56e2dbb405213da9c221e0329d27f411691 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 30 Sep 2022 01:44:38 +0100 Subject: add sampler_noise_scheduler_override property --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..1da753a2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -79,7 +79,7 @@ class StableDiffusionProcessing: self.paste_to = None self.color_corrections = None self.denoising_strength: float = 0 - + self.sampler_noise_scheduler_override = None self.ddim_discretize = opts.ddim_discretize self.s_churn = opts.s_churn self.s_tmin = opts.s_tmin @@ -130,7 +130,7 @@ class Processed: self.s_tmin = p.s_tmin self.s_tmax = p.s_tmax self.s_noise = p.s_noise - + self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) -- cgit v1.2.1 From bc38c80cfc83d4e2fc09c02dd49355664c05d15c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 30 Sep 2022 01:46:06 +0100 Subject: add sampler_noise_scheduler_override switch --- 
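Together with the property added in the previous commit, the switch below lets a caller substitute its own sigma schedule. A hypothetical example of what a custom script could plug in (get_sigmas_karras is k-diffusion's scheduling helper; the parameter values are illustrative assumptions, not taken from the patch):

import k_diffusion.sampling as K

def karras_override(steps):
    # Return one sigma per step, exactly as model_wrap.get_sigmas would.
    return K.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10.0, device="cuda")

p.sampler_noise_scheduler_override = karras_override  # p: StableDiffusionProcessing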
modules/sd_samplers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index dff89c09..92522214 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -290,7 +290,10 @@ class KDiffusionSampler: def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): steps, t_enc = setup_img2img_steps(p, steps) - sigmas = self.model_wrap.get_sigmas(steps) + if p.sampler_noise_scheduler_override: + sigmas = p.sampler_noise_scheduler_override(steps) + else: + sigmas = self.model_wrap.get_sigmas(steps) noise = noise * sigmas[steps - t_enc - 1] xi = x + noise @@ -306,7 +309,10 @@ class KDiffusionSampler: def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): steps = steps or p.steps - sigmas = self.model_wrap.get_sigmas(steps) + if p.sampler_noise_scheduler_override: + sigmas = p.sampler_noise_scheduler_override(steps) + else: + sigmas = self.model_wrap.get_sigmas(steps) x = x * sigmas[0] extra_params_kwargs = self.initialize(p) -- cgit v1.2.1 From 4c2478a68a4f11959fe4887d38e0436eac19f97e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 18:30:53 +0100 Subject: add script reload method --- modules/scripts.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 7c3bd5e7..3c14b9e3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -165,3 +165,12 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + +def reload_scripts(basedir): + global scripts_txt2img,scripts_img2img + + scripts_data.clear() + load_scripts(basedir) + + scripts_txt2img = ScriptRunner() + scripts_img2img = ScriptRunner() -- cgit v1.2.1 From 4f8490cd5630823ac44de8b5c5e4325bdbbea7fa Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 18:33:31 +0100 Subject: add restart button --- modules/ui.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..ec6aaa28 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1002,6 +1002,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): _js='function(){}' ) + def request_restart(): + settings_interface.gradio_ref.do_restart = True + + restart_gradio = gr.Button(value='Restart Gradio and Refresh Scripts') + restart_gradio.click( + fn=request_restart, + inputs=[], + outputs=[], + _js='function(){document.body.innerHTML=\'
Reloading
\';setTimeout(function(){location.reload()},2000)}' + ) + if column is not None: column.__exit__() @@ -1026,7 +1037,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): css += css_hide_progressbar with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - + + settings_interface.gradio_ref = demo + with gr.Tabs() as tabs: for interface, label, ifid in interfaces: with gr.TabItem(label, id=ifid): -- cgit v1.2.1 From 121ed7d36febe94995774973b5edc1ba2ba84aad Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Sat, 1 Oct 2022 14:04:20 -0400 Subject: Add progress bar for SwinIR in cmd I do not know how to add them to the UI... --- modules/swinir_model.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 41fda5a7..9bd454c6 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -5,6 +5,7 @@ import numpy as np import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url +from tqdm import tqdm from modules import modelloader from modules.paths import models_path @@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale): E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) W = torch.zeros_like(E, dtype=torch.half, device=device) - for h_idx in h_idx_list: - for w_idx in w_idx_list: - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) output = E.div_(W) return output -- cgit v1.2.1 From afaa03c5fd05f48ed9c9f15558ea6f0bc4f61628 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 22:43:45 +0100 Subject: add redefinition guard to gradio_routes_templates_response --- modules/ui.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ec6aaa28..fd057916 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1219,12 +1219,13 @@ for filename in sorted(os.listdir(jsdir)): javascript += f"\n" -def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res +if 'gradio_routes_templates_response' not in globals(): + def template_response(*args, **kwargs): + res = gradio_routes_templates_response(*args, **kwargs) + res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + gradio_routes_templates_response = gradio.routes.templates.TemplateResponse + gradio.routes.templates.TemplateResponse = template_response -gradio_routes_templates_response = 
gradio.routes.templates.TemplateResponse -gradio.routes.templates.TemplateResponse = template_response -- cgit v1.2.1 From 6048002dade91b82b1ce9fea3c6ff5b5c1f8c990 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 23:10:07 +0100 Subject: Add scope warning to refresh button --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index fd057916..72846a12 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1005,7 +1005,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def request_restart(): settings_interface.gradio_ref.do_restart = True - restart_gradio = gr.Button(value='Restart Gradio and Refresh Scripts') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') restart_gradio.click( fn=request_restart, inputs=[], -- cgit v1.2.1 From 027c5aae5546ff3650347cb3c2b87df4415ab900 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 23:29:26 +0100 Subject: update reloading message style --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 72846a12..7b2359c2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1010,7 +1010,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): fn=request_restart, inputs=[], outputs=[], - _js='function(){document.body.innerHTML=\'
Reloading
\';setTimeout(function(){location.reload()},2000)}' + _js='function(){document.body.innerHTML=\'
Reloading...
\';setTimeout(function(){location.reload()},2000)}' ) if column is not None: -- cgit v1.2.1 From 0aa354bd5e811e2b41b17a3052cf5d4c8190d533 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 00:13:47 +0100 Subject: remove styling from python side --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 7b2359c2..cb859ac4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1010,7 +1010,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): fn=request_restart, inputs=[], outputs=[], - _js='function(){document.body.innerHTML=\'
Reloading...
\';setTimeout(function(){location.reload()},2000)}' + _js='function(){restart_reload()}' ) if column is not None: -- cgit v1.2.1 From cf33268d686986a24f2e04eb615f01ed53bfe308 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 01:18:42 +0100 Subject: add script body only refresh --- modules/scripts.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 3c14b9e3..788397f5 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -162,10 +162,33 @@ class ScriptRunner: return processed + def reload_sources(self): + for si,script in list(enumerate(self.scripts)): + with open(script.filename, "r", encoding="utf8") as file: + args_from = script.args_from + args_to = script.args_to + filename = script.filename + text = file.read() + + from types import ModuleType + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +def reload_script_body_only(): + scripts_txt2img.reload_sources() + scripts_img2img.reload_sources() + def reload_scripts(basedir): global scripts_txt2img,scripts_img2img -- cgit v1.2.1 From 07e40ad7f23472fc1c781fe1cc6c1ee403413918 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 01:19:55 +0100 Subject: add custom script body only refresh option --- modules/ui.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cb859ac4..eb7c0585 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1012,6 +1012,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[], _js='function(){restart_reload()}' ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='primary') + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[], + _js='function(){}' + ) if column is not None: column.__exit__() -- cgit v1.2.1 From 2deea867814272f1f089b60e9ba8d587c16b2fb1 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 01:36:30 +0100 Subject: Put reload buttons in row and add secondary style --- modules/ui.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index eb7c0585..963a2c61 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1002,27 +1002,30 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): _js='function(){}' ) - def request_restart(): - settings_interface.gradio_ref.do_restart = True + with gr.Row(): + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - 
restart_gradio.click( - fn=request_restart, - inputs=[], - outputs=[], - _js='function(){restart_reload()}' - ) def reload_scripts(): modules.scripts.reload_script_body_only() - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='primary') reload_script_bodies.click( fn=reload_scripts, inputs=[], outputs=[], _js='function(){}' ) + + def request_restart(): + settings_interface.gradio_ref.do_restart = True + + restart_gradio.click( + fn=request_restart, + inputs=[], + outputs=[], + _js='function(){restart_reload()}' + ) if column is not None: column.__exit__() -- cgit v1.2.1 From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:03:39 +0300 Subject: initial support for training textual inversion --- modules/devices.py | 3 +- modules/processing.py | 13 +- modules/sd_hijack.py | 324 ++++--------------------- modules/sd_hijack_optimizations.py | 164 +++++++++++++ modules/sd_models.py | 4 +- modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 76 ++++++ modules/textual_inversion/textual_inversion.py | 258 ++++++++++++++++++++ modules/textual_inversion/ui.py | 32 +++ modules/ui.py | 139 +++++++++-- 10 files changed, 717 insertions(+), 299 deletions(-) create mode 100644 modules/sd_hijack_optimizations.py create mode 100644 modules/textual_inversion/dataset.py create mode 100644 modules/textual_inversion/textual_inversion.py create mode 100644 modules/textual_inversion/ui.py (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..ff82f2f6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,10 +32,9 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") - device = get_optimal_device() device_codeformer = cpu if has_mps else device - +dtype = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. 
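The comment in randn above points at the usual workaround: seed and draw the noise on the CPU, then move it to the MPS device. A hedged sketch of that pattern (an illustration, not the patch's actual function body, which is elided here):

import torch

def seeded_randn(seed, shape, device):
    # On the metal backend, draw noise on the CPU with a seeded generator so
    # that identical seeds reproduce identical images, then move it over.
    if device.type == "mps":
        generator = torch.Generator("cpu").manual_seed(seed)
        return torch.randn(shape, generator=generator, device="cpu").to(device)
    torch.manual_seed(seed)
    return torch.randn(shape, device=device)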
diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..8223423a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -56,7 +56,7 @@ class StableDiffusionProcessing: self.prompt: str = prompt self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") - self.styles: str = styles + self.styles: list = styles or [] self.seed: int = seed self.subseed: int = subseed self.subseed_strength: float = subseed_strength @@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), } generation_params.update(p.extra_generation_params) @@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: fix_seed(p) - os.makedirs(p.outpath_samples, exist_ok=True) - os.makedirs(p.outpath_grids, exist_ok=True) + if p.outpath_samples is not None: + os.makedirs(p.outpath_samples, exist_ok=True) + + if p.outpath_grids is not None: + os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) @@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir): - model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) + model_hijack.embedding_db.load_textual_inversion_embeddings() infotexts = [] output_images = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa7eaeb8..fd57e5c5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,244 +6,41 @@ import torch import numpy as np from torch import einsum -from modules import prompt_parser +import modules.textual_inversion.textual_inversion +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts -from ldm.util import default -from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x +def apply_optimizations(): + if cmd_opts.opt_split_attention_v1: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + 
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale +def undo_optimizations(): + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion -def split_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - -def cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 class StableDiffusionModelHijack: - ids_lookup = {} - word_embeddings = {} - word_embeddings_checksums = {} fixes = None comments = [] - dir_mtime = None layers = None circular_enabled = False clip = None - def load_textual_inversion_embeddings(self, dirname, model): - mt = os.path.getmtime(dirname) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - - tokenizer = model.cond_stage_model.tokenizer - - def const_hash(a): - r = 0 - for v in a: - r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF - return r - - def process_file(path, filename): - name = os.path.splitext(filename)[0] - - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - 
self.word_embeddings[name] = emb.detach().to(device) - self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' - - ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] - - first_id = ids[0] - if first_id not in self.ids_lookup: - self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, name)) - - for fn in os.listdir(dirname): - try: - fullfn = os.path.join(dirname, fn) - - if os.stat(fullfn).st_size == 0: - continue - - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -253,12 +50,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model - if cmd_opts.opt_split_attention_v1: - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.hijack = hijack + self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - if possible_matches is None: + if embedding is None: remade_tokens.append(token) multipliers.append(weight) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i + len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = 
self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: mult *= mult_change - elif possible_matches is None: + i += 1 + elif embedding is None: remade_tokens.append(token) multipliers.append(mult) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i+len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(mult) - - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] @@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments @@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module): inputs_embeds = self.wrapped(input_ids) - if batch_fixes is not None: - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - emb = self.embeddings.word_embeddings[word] - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + emb = embedding.vec + emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + + vecs.append(tensor) - return inputs_embeds + return torch.stack(vecs) def add_circular_option_to_conv_2d(): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py new file mode 100644 index 00000000..9c079e57 --- /dev/null +++ b/modules/sd_hijack_optimizations.py @@ -0,0 +1,164 @@ +import math +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = 
map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) * self.scale + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + +def nonlinearity_hijack(x): + # swish + t = torch.sigmoid(x) + x *= t + del t + + return x + +def cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 diff --git a/modules/sd_models.py b/modules/sd_models.py index 2539f14c..5b3dbdc7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader +from modules import shared, modelloader, devices from modules.paths import models_path model_dir = "Stable-diffusion" @@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): if not shared.cmd_opts.no_half: model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ac968b2d..ac0bc480 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + textinfo = None def interrupt(self): self.interrupted = True @@ -88,7 +89,7 @@ class State: self.current_image_sampling_step = 0 def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") + return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? 
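# --- editor's note: illustrative sketch, not part of the patches above or below ---
# The split_cross_attention_forward variants added in sd_hijack_optimizations.py
# above all follow the same recipe: estimate how large the (batch*heads, q_len,
# k_len) attention matrix would be, compare it against free VRAM, and double the
# number of softmax slices until the work fits. A minimal standalone restatement
# of that sizing logic (the function name and the 2.5x safety modifier are taken
# from the patch; treat this as an approximation, not a canonical API):
import math

import torch


def attention_slice_count(q, k, modifier=2.5):
    # bytes for one full attention matrix: (b*h) x q_len x k_len elements
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    mem_required = tensor_size * modifier  # headroom for softmax temporaries

    stats = torch.cuda.memory_stats(q.device)
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    # memory torch has reserved but is not actively using is also usable
    mem_free_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    mem_free_total = mem_free_cuda + mem_free_torch

    if mem_required <= mem_free_total:
        return 1
    # round the oversubscription ratio up to the next power of two,
    # exactly as the patch does with 2 ** ceil(log(ratio, 2))
    return 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
# --- end editor's note ---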
state = State() diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py new file mode 100644 index 00000000..7e134a08 --- /dev/null +++ b/modules/textual_inversion/dataset.py @@ -0,0 +1,76 @@ +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random +import tqdm + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + + self.placeholder_token = placeholder_token + + self.size = size + self.width = width + self.height = height + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + image = Image.open(path) + image = image.convert('RGB') + image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + + filename = os.path.basename(path) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) + torchdata = torch.moveaxis(torchdata, 2, 0) + + init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + + self.dataset.append((init_latent, filename_tokens)) + + self.length = len(self.dataset) * repeats + + self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.indexes = None + self.shuffle() + + def shuffle(self): + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i % len(self.dataset) == 0: + self.shuffle() + + index = self.indexes[i % len(self.indexes)] + x, filename_tokens = self.dataset[index] + + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + + return x, text diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 00000000..c0baaace --- /dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,258 @@ +import os +import sys +import traceback + +import torch +import tqdm +import html +import datetime + +from modules import shared, devices, sd_hijack, processing +import modules.textual_inversion.dataset + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.cached_checksum = None + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + } + + torch.save(embedding_data, filename) + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = 
f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + +class EmbeddingDatabase: + def __init__(self, embeddings_dir): + self.ids_lookup = {} + self.word_embeddings = {} + self.dir_mtime = None + self.embeddings_dir = embeddings_dir + + def register_embedding(self, embedding, model): + + self.word_embeddings[embedding.name] = embedding + + ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + self.ids_lookup[first_id].append((ids, embedding)) + + return embedding + + def load_textual_inversion_embeddings(self): + mt = os.path.getmtime(self.embeddings_dir) + if self.dir_mtime is not None and mt <= self.dir_mtime: + return + + self.dir_mtime = mt + self.ids_lookup.clear() + self.word_embeddings.clear() + + def process_file(path, filename): + name = os.path.splitext(filename)[0] + + data = torch.load(path, map_location="cpu") + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + self.register_embedding(embedding, shared.sd_model) + + for fn in os.listdir(self.embeddings_dir): + try: + fullfn = os.path.join(self.embeddings_dir, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading emedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + continue + + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding + + return None + + + +def create_embedding(name, num_vectors_per_token): + init_text = '*' + + cond_model = shared.sd_model.cond_stage_model + embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def train_embedding(embedding_name, learn_rate, 
data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): + assert embedding_name, 'embedding not selected' + + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + embedding.vec.requires_grad = True + + optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = embedding.step or 0 + if ititial_step > steps: + return embedding, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, (x, text) in pbar: + embedding.step = i + ititial_step + + if embedding.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + embedding.save(last_saved_file) + + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +
<p> +Loss: {losses.mean():.7f}<br/> +Step: {embedding.step}<br/> +Last prompt: {html.escape(text)}<br/> +Last saved embedding: {html.escape(last_saved_file)}<br/> +Last saved image: {html.escape(last_saved_image)}<br/> +</p>
+""" + + embedding.cached_checksum = None + embedding.save(filename) + + return embedding, filename + diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 00000000..ce3677a9 --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,32 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion as ti +from modules import sd_hijack, shared + + +def create_embedding(name, nvpt): + filename = ti.create_embedding(name, nvpt) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def train_embedding(*args): + + try: + sd_hijack.undo_optimizations() + + embedding, filename = ti.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + sd_hijack.apply_optimizations() diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..57aef6ff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,6 +21,7 @@ import gradio as gr import gradio.utils import gradio.routes +from modules import sd_hijack from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +import modules.textual_inversion.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -142,8 +144,8 @@ def save_files(js_data, images, index): return '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func): - def f(*args, **kwargs): +def wrap_gradio_call(func, extra_outputs=None): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled if run_memmon: shared.mem_mon.monitor() @@ -159,7 +161,10 @@ def wrap_gradio_call(func): shared.state.job = "" shared.state.job_count = 0 - res = [None, '', f"
<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>
"] + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>
"] elapsed = time.perf_counter() - t @@ -179,6 +184,7 @@ def wrap_gradio_call(func): res[-1] += f"
<div class='performance'><p>Time taken: {elapsed:.2f}s</p>{vram_html}</div>
" shared.state.interrupted = False + shared.state.job_count = 0 return tuple(res) @@ -187,7 +193,7 @@ def wrap_gradio_call(func): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False) + return "", gr_show(False), gr_show(False), gr_show(False) progress = 0 @@ -219,13 +225,19 @@ def check_progress_call(id_part): else: preview_visibility = gr_show(True) - return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result def check_progress_call_initial(id_part): shared.state.job_count = -1 shared.state.current_latent = None shared.state.current_image = None + shared.state.textinfo = None return check_progress_call(id_part) @@ -399,13 +411,16 @@ def create_toprow(is_img2img): return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste -def setup_progressbar(progressbar, preview, id_part): +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress.click( fn=lambda: check_progress_call(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) @@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part): fn=lambda: check_progress_call_initial(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) -def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( - fn=txt2img, + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ txt2img_prompt, @@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) img2img_args = dict( - fn=img2img, + fn=wrap_gradio_gpu_call(modules.img2img.img2img), _js="submit_img2img", inputs=[ dummy_component, @@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): open_extras_folder = gr.Button('Open output directory', elem_id=button_id) submit.click( - fn=run_extras, + fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): pnginfo_send_to_img2img = gr.Button('Send to img2img') image.change( - fn=wrap_gradio_call(run_pnginfo), + fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) @@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") - + with gr.Row(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") @@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Safe as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - + with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks() as textual_inversion_interface: + with gr.Row().style(equal_height=False): + with gr.Column(): + with gr.Group(): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>Create a new embedding</p>
") + + new_embedding_name = gr.Textbox(label="Name") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create", variant='primary') + + with gr.Group(): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>
") + train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + + with gr.Row(): + with gr.Column(scale=2): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_embedding = gr.Button(value="Train", variant='primary') + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + nvpt, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (textual_inversion_interface, "Textual inversion", "ti"), (settings_interface, "Settings", "settings"), ] @@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def modelmerger(*args): try: - results = run_modelmerger(*args) + results = modules.extras.run_modelmerger(*args) except Exception as e: print("Error loading/saving model file:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() #To remove the potentially missing models from the list + modules.sd_models.list_models() # to remove the potentially missing models from the list return ["Error loading/saving model file. 
It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return results -- cgit v1.2.1 From 0114057ad672a581bd0b598870b58b674b1a3624 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:49:42 +0300 Subject: fix incorrect use of glob in modelloader for #1410 --- modules/modelloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index 8c862b42..015aeafa 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -43,7 +43,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None for place in places: if os.path.exists(place): for file in glob.iglob(place + '**/**', recursive=True): - full_path = os.path.join(place, file) + full_path = file if os.path.isdir(full_path): continue if len(ext_filter) != 0: -- cgit v1.2.1 From 0758f6e641b5790ce566a998d43e0ea74a627766 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 17:24:50 +0300 Subject: fix --ckpt option breaking model selection --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5b3dbdc7..9259d69e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -69,7 +69,7 @@ def list_models(): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.sd_model_checkpoint = title + shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: -- cgit v1.2.1 From 88ec0cf5571883d84abd09196652b3679e359f2e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 19:40:51 +0300 Subject: fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!) 
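# --- editor's note: illustrative sketch of the fix described in this commit ---
# An embedding name such as "aaaa-100" can tokenize to several ids, while the
# embedding itself may hold a different number of vectors. The bug was advancing
# the prompt cursor by the vector count; the fix has the lookup also return how
# many tokenizer ids actually matched, so the caller advances by that instead.
# Hedged sketch mirroring the patched find_embedding_at_position (the real
# method lives on EmbeddingDatabase; ids_lookup is passed in here for brevity):
def find_embedding_at_position(ids_lookup, tokens, offset):
    token = tokens[offset]
    possible_matches = ids_lookup.get(token, None)

    if possible_matches is None:
        return None, None

    for ids, embedding in possible_matches:
        if tokens[offset:offset + len(ids)] == ids:
            # advance i by len(ids) (matched token count), not embedding.vec.shape[0]
            return embedding, len(ids)

    return None, None
# --- end editor's note ---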
add option to input initialization text for embeddings --- modules/sd_hijack.py | 8 ++++---- modules/textual_inversion/textual_inversion.py | 13 +++++-------- modules/textual_inversion/ui.py | 4 ++-- modules/ui.py | 2 ++ 4 files changed, 13 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fd57e5c5..3fa06242 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) if embedding is None: remade_tokens.append(token) @@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [mult] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c0baaace..0c50161d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -117,24 +117,21 @@ class EmbeddingDatabase: possible_matches = self.ids_lookup.get(token, None) if possible_matches is None: - return None + return None, None for ids, embedding in possible_matches: if tokens[offset:offset + len(ids)] == ids: - return embedding + return embedding, len(ids) - return None + return None, None - -def create_embedding(name, num_vectors_per_token): - init_text = '*' - +def create_embedding(name, num_vectors_per_token, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) for i in range(num_vectors_per_token): diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index ce3677a9..66c43ffb 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti from modules import sd_hijack, shared -def create_embedding(name, nvpt): - filename = ti.create_embedding(name, nvpt) +def 
create_embedding(name, initialization_text, nvpt): + filename = ti.create_embedding(name, nvpt, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() diff --git a/modules/ui.py b/modules/ui.py index 3b81a4f7..eca50df0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="
<p style='margin-bottom: 0.7em'>Create a new embedding</p>
") new_embedding_name = gr.Textbox(label="Name") + initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) with gr.Row(): @@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.textual_inversion.ui.create_embedding, inputs=[ new_embedding_name, + initialization_text, nvpt, ], outputs=[ -- cgit v1.2.1 From 71fe7fa49f5eb1a2c89932a9d217ed153c12fc8b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 19:56:37 +0300 Subject: fix using aaaa-100 embedding when the prompt has aaaa-10000 and you have both aaaa-100 and aaaa-10000 in the directory with embeddings. --- modules/textual_inversion/textual_inversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 0c50161d..9d2241ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -57,7 +57,8 @@ class EmbeddingDatabase: first_id = ids[0] if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, embedding)) + + self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) return embedding -- cgit v1.2.1 From 4ec4af6e0b7addeee5221a03f32d117ccdc875d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 20:15:25 +0300 Subject: add checkpoint info to saved embeddings --- modules/textual_inversion/textual_inversion.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9d2241ce..1183aab7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,7 +7,7 @@ import tqdm import html import datetime -from modules import shared, devices, sd_hijack, processing +from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -17,6 +17,8 @@ class Embedding: self.name = name self.step = step self.cached_checksum = None + self.sd_checkpoint = None + self.sd_checkpoint_name = None def save(self, filename): embedding_data = { @@ -24,6 +26,8 @@ class Embedding: "string_to_param": {"*": self.vec}, "name": self.name, "step": self.step, + "sd_checkpoint": self.sd_checkpoint, + "sd_checkpoint_name": self.sd_checkpoint_name, } torch.save(embedding_data, filename) @@ -41,6 +45,7 @@ class Embedding: self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum + class EmbeddingDatabase: def __init__(self, embeddings_dir): self.ids_lookup = {} @@ -96,6 +101,8 @@ class EmbeddingDatabase: vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('hash', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) self.register_embedding(embedding, shared.sd_model) for fn in os.listdir(self.embeddings_dir): @@ -249,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}
</p>
""" + checkpoint = sd_models.select_checkpoint() + + embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint_name = checkpoint.model_name embedding.cached_checksum = None embedding.save(filename) -- cgit v1.2.1 From 3ff0de2c594b786ef948a89efb1814c59bb42117 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 20:23:40 +0300 Subject: added --disable-console-progressbars to disable progressbars in console disabled printing prompts to console by default, enabled by --enable-console-prompts --- modules/img2img.py | 4 +++- modules/sd_samplers.py | 8 ++++++-- modules/shared.py | 7 +++++-- modules/txt2img.py | 4 +++- 4 files changed, 17 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 03e934e9..f4455c90 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -103,7 +103,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, ) - print(f"\nimg2img: {prompt}", file=shared.progress_print_out) + + if shared.cmd_opts.enable_console_prompts: + print(f"\nimg2img: {prompt}", file=shared.progress_print_out) p.extra_generation_params["Mask blur"] = mask_blur diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 92522214..9316875a 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -77,7 +77,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): state.sampling_steps = len(sequence) state.sampling_step = 0 - for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs): + seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) + + for x in seq: if state.interrupted: break @@ -207,7 +209,9 @@ def extended_trange(sampler, count, *args, **kwargs): state.sampling_steps = count state.sampling_step = 0 - for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs): + seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) + + for x in seq: if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 5a591dc9..1bf7a6c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -58,6 +58,9 @@ parser.add_argument("--opt-channelslast", action='store_true', help="change memo parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) +parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) + cmd_opts = parser.parse_args() device = get_optimal_device() @@ -320,14 +323,14 @@ class TotalTQDM: ) def update(self): - if not opts.multiple_tqdm: + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: return if self._tqdm is None: self.reset() self._tqdm.update() def 
updateTotal(self, new_total): - if not opts.multiple_tqdm: + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: return if self._tqdm is None: self.reset() diff --git a/modules/txt2img.py b/modules/txt2img.py index 5368e4d0..d4406c3c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: denoising_strength=denoising_strength if enable_hr else None, ) - print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + if cmd_opts.enable_console_prompts: + print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: -- cgit v1.2.1 From 6365a41f5981efa506dfe4e8fa878b43ca2d8d0c Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Oct 2022 12:58:17 -0500 Subject: Update esrgan_model.py Use alternate ESRGAN Model download path. --- modules/esrgan_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index ea91abfe..4aed9283 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -73,8 +73,8 @@ def fix_model_layers(crt_model, pretrained_net): class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" - self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" - self.model_name = "ESRGAN 4x" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth" + self.model_name = "ESRGAN_4x" self.scalers = [] self.user_path = dirname self.model_path = os.path.join(models_path, self.name) -- cgit v1.2.1 From a1cde7e6468f80584030525a1b07cbf0f4ee42eb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:09:10 +0300 Subject: disabled SD model download after multiple complaints --- modules/sd_models.py | 18 ++++++++---------- modules/textual_inversion/ui.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 9259d69e..9a6b568f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,9 +13,6 @@ from modules.paths import models_path model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -model_name = "sd-v1-4.ckpt" -model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" -user_dir = None CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} @@ -30,12 +27,10 @@ except Exception: pass -def setup_model(dirname): - global user_dir - user_dir = dirname +def setup_model(): if not os.path.exists(model_path): os.makedirs(model_path) - checkpoints_list.clear() + list_models() @@ -45,7 +40,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -106,8 +101,11 @@ def select_checkpoint(): if len(checkpoints_list) == 0: print(f"No checkpoints found. 
When searching for checkpoints, looked at:", file=sys.stderr) - print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) - print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) + if shared.cmd_opts.ckpt is not None: + print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) + print(f" - directory {model_path}", file=sys.stderr) + if shared.cmd_opts.ckpt_dir is not None: + print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 66c43ffb..633037d8 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -22,7 +22,7 @@ def train_embedding(*args): embedding, filename = ti.train_embedding(*args) res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. Embedding saved to {html.escape(filename)} """ return res, "" -- cgit v1.2.1 From 852fd90c0dcda9cb5fbbfdf0c7308ce58034935c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:22:20 +0300 Subject: emergency fix for disabling SD model download after multiple complaints --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 9a6b568f..5f992064 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -45,8 +45,8 @@ def list_models(): def modeltitle(path, shorthash): abspath = os.path.abspath(path) - if user_dir is not None and abspath.startswith(user_dir): - name = abspath.replace(user_dir, '') + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') elif abspath.startswith(model_path): name = abspath.replace(model_path, '') else: -- cgit v1.2.1 From e808096cf641d868f88465515d70d40fc46125d4 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 19:26:06 +0100 Subject: correct indent --- modules/scripts.py | 48 +++++++++++++++++++++++++----------------------- modules/ui.py | 25 ++++++++++++------------- 2 files changed, 37 insertions(+), 36 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 788397f5..45230f9a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -163,37 +163,39 @@ class ScriptRunner: return processed def reload_sources(self): - for si,script in list(enumerate(self.scripts)): - with open(script.filename, "r", encoding="utf8") as file: - args_from = script.args_from - args_to = script.args_to - filename = script.filename - text = file.read() + for si, script in list(enumerate(self.scripts)): + with open(script.filename, "r", encoding="utf8") as file: + args_from = script.args_from + args_to = script.args_to + filename = script.filename + text = file.read() - from types import ModuleType - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) + from types import ModuleType - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - self.scripts[si] = script_class() - 
self.scripts[si].filename = filename - self.scripts[si].args_from = args_from - self.scripts[si].args_to = args_to + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() def reload_script_body_only(): - scripts_txt2img.reload_sources() - scripts_img2img.reload_sources() + scripts_txt2img.reload_sources() + scripts_img2img.reload_sources() + def reload_scripts(basedir): - global scripts_txt2img,scripts_img2img + global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + scripts_data.clear() + load_scripts(basedir) - scripts_txt2img = ScriptRunner() - scripts_img2img = ScriptRunner() + scripts_txt2img = ScriptRunner() + scripts_img2img = ScriptRunner() diff --git a/modules/ui.py b/modules/ui.py index 963a2c61..6b30f84b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1003,12 +1003,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') def reload_scripts(): - modules.scripts.reload_script_body_only() + modules.scripts.reload_script_body_only() reload_script_bodies.click( fn=reload_scripts, @@ -1018,7 +1018,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) def request_restart(): - settings_interface.gradio_ref.do_restart = True + settings_interface.gradio_ref.do_restart = True restart_gradio.click( fn=request_restart, @@ -1234,12 +1234,11 @@ for filename in sorted(os.listdir(jsdir)): if 'gradio_routes_templates_response' not in globals(): - def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio_routes_templates_response = gradio.routes.templates.TemplateResponse - gradio.routes.templates.TemplateResponse = template_response - + def template_response(*args, **kwargs): + res = gradio_routes_templates_response(*args, **kwargs) + res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio_routes_templates_response = gradio.routes.templates.TemplateResponse + gradio.routes.templates.TemplateResponse = template_response -- cgit v1.2.1 From 91f327f22bb2feb780c424c74723cc0629dc34a1 Mon Sep 17 00:00:00 2001 From: Lopyter Date: Sun, 2 Oct 2022 18:15:31 +0200 Subject: make save to dirs optional for imgs saved from ui --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 1bf7a6c1..785e7af6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ 
options_templates.update(options_section(('saving-to-dirs', "Saving to a directo "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"), "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "use_save_to_dirs_for_ui": OptionInfo(False, "Use \"Save images to a subdirectory\" option for images saved from UI"), })) options_templates.update(options_section(('upscaling', "Upscaling"), { diff --git a/modules/ui.py b/modules/ui.py index 78a15d83..8912deff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -113,7 +113,7 @@ def save_files(js_data, images, index): p = MyObject(data) path = opts.outdir_save - save_to_dirs = opts.save_to_dirs + save_to_dirs = opts.use_save_to_dirs_for_ui if save_to_dirs: dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt) -- cgit v1.2.1 From c4445225f79f1c57afe52358ff4b205864eb7aac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:50:14 +0300 Subject: change wording for options --- modules/shared.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 785e7af6..7246eadc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -170,10 +170,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), - "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), + "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), + "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"), - "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), - "use_save_to_dirs_for_ui": OptionInfo(False, "Use \"Save images to a subdirectory\" option for images saved from UI"), + "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), })) options_templates.update(options_section(('upscaling', "Upscaling"), { -- cgit v1.2.1 From c7543d4940da672d970124ae8f2fec9de7bdc1da Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 22:41:21 +0300 Subject: preprocessing for textual inversion added --- modules/interrogate.py | 1 + modules/textual_inversion/preprocess.py | 75 ++++++++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 1 + modules/textual_inversion/ui.py | 14 +++-- modules/ui.py | 36 +++++++++++++ 5 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 modules/textual_inversion/preprocess.py (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index f62a4745..eed87144 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") + class InterrogateModels: blip_model = None clip_model = None diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py new file mode 100644 index 00000000..209e928f --- /dev/null +++ 
From c7543d4940da672d970124ae8f2fec9de7bdc1da Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 22:41:21 +0300
Subject: preprocessing for textual inversion added

---
 modules/interrogate.py                         |  1 +
 modules/textual_inversion/preprocess.py        | 75 ++++++++++++++++++++++++++
 modules/textual_inversion/textual_inversion.py |  1 +
 modules/textual_inversion/ui.py                | 14 +++-
 modules/ui.py                                  | 36 +++++++++++++
 5 files changed, 124 insertions(+), 3 deletions(-)
 create mode 100644 modules/textual_inversion/preprocess.py

(limited to 'modules')

diff --git a/modules/interrogate.py b/modules/interrogate.py
index f62a4745..eed87144 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
 
 re_topn = re.compile(r"\.top(\d+)\.")
 
+
 class InterrogateModels:
     blip_model = None
     clip_model = None

diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
new file mode 100644
index 00000000..209e928f
--- /dev/null
+++ b/modules/textual_inversion/preprocess.py
@@ -0,0 +1,75 @@
+import os
+from PIL import Image, ImageOps
+import tqdm
+
+from modules import shared, images
+
+
+def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
+    size = 512
+    src = os.path.abspath(process_src)
+    dst = os.path.abspath(process_dst)
+
+    assert src != dst, 'same directory specified as source and destination'
+
+    os.makedirs(dst, exist_ok=True)
+
+    files = os.listdir(src)
+
+    shared.state.textinfo = "Preprocessing..."
+    shared.state.job_count = len(files)
+
+    if process_caption:
+        shared.interrogator.load()
+
+    def save_pic_with_caption(image, index):
+        if process_caption:
+            caption = "-" + shared.interrogator.generate_caption(image)
+        else:
+            caption = ""
+
+        image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
+        subindex[0] += 1
+
+    def save_pic(image, index):
+        save_pic_with_caption(image, index)
+
+        if process_flip:
+            save_pic_with_caption(ImageOps.mirror(image), index)
+
+    for index, imagefile in enumerate(tqdm.tqdm(files)):
+        subindex = [0]
+        filename = os.path.join(src, imagefile)
+        img = Image.open(filename).convert("RGB")
+
+        if shared.state.interrupted:
+            break
+
+        ratio = img.height / img.width
+        is_tall = ratio > 1.35
+        is_wide = ratio < 1 / 1.35
+
+        if process_split and is_tall:
+            img = img.resize((size, size * img.height // img.width))
+
+            top = img.crop((0, 0, size, size))
+            save_pic(top, index)
+
+            bot = img.crop((0, img.height - size, size, img.height))
+            save_pic(bot, index)
+        elif process_split and is_wide:
+            img = img.resize((size * img.width // img.height, size))
+
+            left = img.crop((0, 0, size, size))
+            save_pic(left, index)
+
+            right = img.crop((img.width - size, 0, img.width, size))
+            save_pic(right, index)
+        else:
+            img = images.resize_image(1, img, size, size)
+            save_pic(img, index)
+
+        shared.state.nextjob()
+
+    if process_caption:
+        shared.interrogator.send_blip_to_ram()

diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1183aab7..d4e250d8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,6 +7,7 @@ import tqdm
 import html
 import datetime
 
+
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset

diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 633037d8..f19ac5e0 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -2,24 +2,31 @@ import html
 
 import gradio as gr
 
-import modules.textual_inversion.textual_inversion as ti
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
 
 
 def create_embedding(name, initialization_text, nvpt):
-    filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
+    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
 
     return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
 
 
+def preprocess(*args):
+    modules.textual_inversion.preprocess.preprocess(*args)
+
+    return "Preprocessing finished.", ""
+
+
 def train_embedding(*args):
 
     try:
         sd_hijack.undo_optimizations()
 
-        embedding, filename = ti.train_embedding(*args)
+        embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
 
         res = f"""
 Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
@@ -30,3 +37,4 @@ Embedding saved to {html.escape(filename)}
         raise
     finally:
         sd_hijack.apply_optimizations()
+

diff --git a/modules/ui.py b/modules/ui.py
index 8912deff..e7bde53b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -961,6 +961,8 @@ def create_ui(wrap_gradio_gpu_call):
         with gr.Row().style(equal_height=False):
             with gr.Column():
                 with gr.Group():
+                    gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
+
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
 
                     new_embedding_name = gr.Textbox(label="Name")
@@ -974,6 +976,24 @@ def create_ui(wrap_gradio_gpu_call):
                     with gr.Column():
                         create_embedding = gr.Button(value="Create", variant='primary')
 
+                with gr.Group():
+                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
+
+                    process_src = gr.Textbox(label='Source directory')
+                    process_dst = gr.Textbox(label='Destination directory')
+
+                    with gr.Row():
+                        process_flip = gr.Checkbox(label='Flip')
+                        process_split = gr.Checkbox(label='Split into two')
+                        process_caption = gr.Checkbox(label='Add caption')
+
+                    with gr.Row():
+                        with gr.Column(scale=3):
+                            gr.HTML(value="")
+
+                        with gr.Column():
+                            run_preprocess = gr.Button(value="Preprocess", variant='primary')
+
                 with gr.Group():
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
 
                     train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
@@ -1018,6 +1038,22 @@ def create_ui(wrap_gradio_gpu_call):
             ]
         )
 
+        run_preprocess.click(
+            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+            _js="start_training_textual_inversion",
+            inputs=[
+                process_src,
+                process_dst,
+                process_flip,
+                process_split,
+                process_caption,
+            ],
+            outputs=[
+                ti_output,
+                ti_outcome,
+            ],
+        )
+
         train_embedding.click(
             fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
-- cgit v1.2.1
From 6785331e22d6a488fbf5905fab56d7fec867e038 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 22:59:01 +0300
Subject: keep textual inversion dataset latents in CPU memory to save a bit of VRAM

---
 modules/textual_inversion/dataset.py           | 2 ++
 modules/textual_inversion/textual_inversion.py | 3 +++
 modules/ui.py                                  | 4 ++--
 3 files changed, 7 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7e134a08..e8394ff6 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,6 +8,7 @@ from torchvision import transforms
 
 import random
 import tqdm
+from modules import devices
 
 
 class PersonalizedBase(Dataset):
@@ -47,6 +48,7 @@ class PersonalizedBase(Dataset):
             torchdata = torch.moveaxis(torchdata, 2, 0)
 
             init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+            init_latent = init_latent.to(devices.cpu)
 
             self.dataset.append((init_latent, filename_tokens))

diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d4e250d8..8686f534 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -212,7 +212,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
 
         with torch.autocast("cuda"):
             c = cond_model([text])
+
+            x = x.to(devices.device)
             loss = shared.sd_model(x.unsqueeze(0), c)[0]
+            del x
 
         losses[embedding.step % losses.shape[0]] = loss.item()

diff --git a/modules/ui.py b/modules/ui.py
index e7bde53b..d9d02ece 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1002,8 +1002,8 @@ def create_ui(wrap_gradio_gpu_call):
                     log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
                     template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
                     steps = gr.Number(label='Max steps', value=100000, precision=0)
-                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
-                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
+                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
+                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
 
         with gr.Row():
             with gr.Column(scale=2):
-- cgit v1.2.1

From 166283653cfe7521a422c91e8fb801f3ecb4adc8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 23:18:13 +0300
Subject: remove LDSR warning

---
 modules/paths.py | 1 -
 1 file changed, 1 deletion(-)

(limited to 'modules')

diff --git a/modules/paths.py b/modules/paths.py
index ceb80417..606f7d66 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -20,7 +20,6 @@ path_dirs = [
     (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
-    (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
 ]
-- cgit v1.2.1
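The VRAM saving in 6785331e comes from parking every cached latent in ordinary CPU memory and moving only the latent for the current step onto the GPU, then dropping it. The pattern in isolation; the tensor shape and the loop body are illustrative stand-ins for the real dataset and training step:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# the whole dataset of latents lives in RAM, not VRAM
cached_latents = [torch.randn(4, 64, 64) for _ in range(16)]  # illustrative latent shape

for latent in cached_latents:
    x = latent.to(device)   # one sample on the GPU at a time
    out = (x * 2.0).sum()   # stand-in for shared.sd_model(x.unsqueeze(0), c)
    del x                   # release the GPU copy immediately, as the patch does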