aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authoryfszzx <yfszzx@gmail.com>2022-10-16 12:34:05 +0800
committeryfszzx <yfszzx@gmail.com>2022-10-16 12:34:05 +0800
commit5d8c59eee505cf15ec6994d05bb941440d90e44e (patch)
treea87f6e38d4dc33e3f293f10f9cdd1760af93114c /modules
parent763b893f319cee280b86e63025eb55e7c16b02e7 (diff)
parentd41ac174e24e1e7cdcf7b42f2a03cbc6394eb5e5 (diff)
Merge branch 'master' of https://github.com/yfszzx/stable-diffusion-webui-plus
Diffstat (limited to 'modules')
-rw-r--r--modules/deepbooru.py2
-rw-r--r--modules/hypernetworks/hypernetwork.py43
-rw-r--r--modules/images_history.py6
-rw-r--r--modules/processing.py12
-rw-r--r--modules/sd_hijack.py2
-rw-r--r--modules/sd_models.py58
-rw-r--r--modules/shared.py4
-rw-r--r--modules/textual_inversion/dataset.py32
-rw-r--r--modules/textual_inversion/textual_inversion.py34
-rw-r--r--modules/ui.py107
10 files changed, 187 insertions, 113 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index f34f3788..4ad334a1 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
- model_path, compile_model=True
+ model_path, compile_model=False
)
return model, tags
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 59c7ac6e..4905710e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out)
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+ return torch.stack(conds)
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, entry in pbar:
+ for i, entries in pbar:
hypernetwork.step = i + ititial_step
scheduler.apply(optimizer, hypernetwork.step)
@@ -246,26 +260,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break
with torch.autocast("cuda"):
- cond = entry.cond.to(devices.device)
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), cond)[0]
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
- del cond
+ del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
-
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ mean_loss = losses.mean()
+ if torch.isnan(mean_loss):
+ raise RuntimeError("Loss diverged.")
+ pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
+ "loss": f"{mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
@@ -292,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
p.width = preview_width
p.height = preview_height
else:
- p.prompt = entry.cond_text
+ p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
@@ -313,9 +330,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"""
<p>
-Loss: {losses.mean():.7f}<br/>
+Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
diff --git a/modules/images_history.py b/modules/images_history.py
index 533cf51b..7fd75005 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -197,14 +197,16 @@ def delete_image(delete_num, tabname, name, page_index, filenames, image_index):
return new_file_list, 1
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
- if tabname == "txt2img":
+ if opts.outdir_samples != "":
+ dir_name = opts.outdir_samples
+ elif tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
d = dir_name.split("/")
- dir_name = d[0]
+ dir_name = "/" if dir_name.startswith("/") else d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
diff --git a/modules/processing.py b/modules/processing.py
index a75b9f84..941ae089 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -140,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
@@ -528,7 +528,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_height_truncated = int(scale * self.height)
else:
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
@@ -540,6 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
@@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
- decoded_samples = decode_first_stage(self.sd_model, samples)
+ if opts.use_scale_latent_for_hires_fix:
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else:
+ decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = []
@@ -578,7 +578,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decoded_samples.to(shared.device)
decoded_samples = 2. * decoded_samples - 1.
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c81722a0..984b35c4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -24,7 +24,7 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3a01c93d..3aa21ec1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,4 +1,4 @@
-import glob
+import collections
import os.path
import sys
from collections import namedtuple
@@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {}
+checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ if checkpoint_info not in checkpoints_loaded:
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
- pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+ pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
+ sd = get_state_dict_from_checkpoint(pl_sd)
+ model.load_state_dict(sd, strict=False)
- sd = get_state_dict_from_checkpoint(pl_sd)
+ if shared.cmd_opts.opt_channelslast:
+ model.to(memory_format=torch.channels_last)
- model.load_state_dict(sd, strict=False)
+ if not shared.cmd_opts.no_half:
+ model.half()
- if shared.cmd_opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- if not shared.cmd_opts.no_half:
- model.half()
+ vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
+ vae_file = shared.cmd_opts.vae_path
- vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
+ if os.path.exists(vae_file):
+ print(f"Loading VAE weights from: {vae_file}")
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ model.first_stage_model.load_state_dict(vae_dict)
- if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
- vae_file = shared.cmd_opts.vae_path
+ model.first_stage_model.to(devices.dtype_vae)
- if os.path.exists(vae_file):
- print(f"Loading VAE weights from: {vae_file}")
-
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
-
- model.first_stage_model.load_state_dict(vae_dict)
-
- model.first_stage_model.to(devices.dtype_vae)
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
+ else:
+ print(f"Loading weights [{sd_model_hash}] from cache")
+ checkpoints_loaded.move_to_end(checkpoint_info)
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
@@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+ checkpoints_loaded.clear()
shared.sd_model = load_model()
return shared.sd_model
diff --git a/modules/shared.py b/modules/shared.py
index d41a7ab3..fa30bbb0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
+ "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -242,6 +243,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -255,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
- 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@@ -283,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 67e90afe..23bb4b6a 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -24,11 +24,12 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
+ self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
if include_cond:
entry.cond_text = self.create_text(filename_text)
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
- self.length = len(self.dataset) * repeats
+ assert len(self.dataset) > 1, "No images have been found in the dataset."
+ self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
return self.length
def __getitem__(self, i):
- if i % len(self.dataset) == 0:
- self.shuffle()
+ res = []
- index = self.indexes[i % len(self.indexes)]
- entry = self.dataset[index]
+ for j in range(self.batch_size):
+ position = i * self.batch_size + j
+ if position % len(self.indexes) == 0:
+ self.shuffle()
- if entry.cond is None:
- entry.cond_text = self.create_text(entry.filename_text)
+ index = self.indexes[position % len(self.indexes)]
+ entry = self.dataset[index]
- return entry
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
+
+ res.append(entry)
+
+ return res
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index da0d77a0..2ed345b1 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -88,9 +88,9 @@ class EmbeddingDatabase:
data = []
- if filename.upper().endswith('.PNG'):
+ if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
- if 'sd-ti-embedding' in embed_image.text:
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name', name)
else:
@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = "<none>"
last_saved_image = "<none>"
+ embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
if ititial_step > steps:
@@ -251,7 +252,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entry in pbar:
+ for i, entries in pbar:
embedding.step = i + ititial_step
scheduler.apply(optimizer, embedding.step)
@@ -262,10 +263,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
with torch.autocast("cuda"):
- c = cond_model([entry.cond_text])
-
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
@@ -282,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
@@ -307,7 +308,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
p.width = preview_width
p.height = preview_height
else:
- p.prompt = entry.cond_text
+ p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
@@ -319,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
@@ -328,15 +329,22 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}'.format(embedding.step)
+ footer_right = '{}v {}s'.format(vectorSize, embedding.step)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
image.save(last_saved_image)
@@ -348,7 +356,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
diff --git a/modules/ui.py b/modules/ui.py
index 1bc919c7..b867d40f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -7,6 +7,7 @@ import mimetypes
import os
import random
import sys
+import tempfile
import time
import traceback
import platform
@@ -80,6 +81,8 @@ art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
+save_style_symbol = '\U0001f4be' # 💾
+apply_style_symbol = '\U0001f4cb' # 📋
def plaintext_to_html(text):
@@ -88,6 +91,14 @@ def plaintext_to_html(text):
def image_from_url_text(filedata):
+ if type(filedata) == dict and filedata["is_file"]:
+ filename = filedata["name"]
+ tempdir = os.path.normpath(tempfile.gettempdir())
+ normfn = os.path.normpath(filename)
+ assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
+
+ return Image.open(filename)
+
if type(filedata) == list:
if len(filedata) == 0:
return None
@@ -143,10 +154,7 @@ def save_files(js_data, images, do_make_zip, index):
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
for image_index, filedata in enumerate(images, start_index):
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
+ image = image_from_url_text(filedata)
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
@@ -176,6 +184,23 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+def save_pil_to_file(pil_image, dir=None):
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in pil_image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+ return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
@@ -304,7 +329,7 @@ def visit(x, func, path=""):
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
- return [gr_show(), gr_show()]
+ return [gr_show() for x in range(4)]
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
shared.prompt_styles.styles[style.name] = style
@@ -429,29 +454,38 @@ def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
with gr.Row(elem_id="toprow"):
- with gr.Column(scale=4):
+ with gr.Column(scale=6):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
- with gr.Column(scale=1, elem_id="roll_col"):
- roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- with gr.Column(scale=10, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
with gr.Row():
- with gr.Column(scale=8):
+ with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
- with gr.Column(scale=1, elem_id="roll_col"):
- sh = gr.Button(elem_id="sh", visible=True)
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+ with gr.Column(scale=1, elem_id="roll_col"):
+ roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
+ save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+ prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+
+ if cmd_opts.deepdanbooru:
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
with gr.Row():
@@ -471,20 +505,14 @@ def create_toprow(is_img2img):
outputs=[],
)
- with gr.Row(scale=1):
- if is_img2img:
- interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- if cmd_opts.deepdanbooru:
- deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- else:
- deepbooru = None
- else:
- interrogate = None
- deepbooru = None
- prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
- save_style = gr.Button('Create style', elem_id="style_create")
+ with gr.Row():
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -588,7 +616,7 @@ def create_ui(wrap_gradio_gpu_call):
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
- with gr.Group():
+ with gr.Column():
with gr.Row():
save = gr.Button('Save')
send_to_img2img = gr.Button('Send to img2img')
@@ -744,10 +772,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
@@ -803,7 +831,7 @@ def create_ui(wrap_gradio_gpu_call):
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
- with gr.Group():
+ with gr.Column():
with gr.Row():
save = gr.Button('Save')
img2img_send_to_img2img = gr.Button('Send to img2img')
@@ -1166,6 +1194,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1244,6 +1273,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_embedding_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
training_width,
@@ -1268,6 +1298,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_hypernetwork_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
steps,