aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--modules/hypernetworks/hypernetwork.py61
-rw-r--r--modules/hypernetworks/ui.py10
2 files changed, 54 insertions, 17 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index aa701bda..8314450a 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,6 +14,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class HypernetworkModule(torch.nn.Module):
@@ -42,7 +43,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None):
+ def __init__(self, name=None, enable_sizes=None):
self.filename = None
self.name = name
self.layers = {}
@@ -50,7 +51,7 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
- for size in [320, 640, 768, 1280]:
+ for size in enable_sizes or []:
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
def weights(self):
@@ -119,6 +120,17 @@ def load_hypernetwork(filename):
shared.loaded_hypernetwork = None
+def find_closest_hypernetwork_name(search: str):
+ if not search:
+ return None
+ search = search.lower()
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
+ if not applicable:
+ return None
+ applicable = sorted(applicable, key=lambda name: len(name))
+ return applicable[0]
+
+
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
@@ -163,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
- assert hypernetwork_name, 'embedding not selected'
+ assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -175,6 +187,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
if save_hypernetwork_every > 0:
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
@@ -188,19 +201,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
else:
images_dir = None
- cond_model = shared.sd_model.cond_stage_model
-
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
@@ -210,22 +223,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text) in pbar:
+ for i, (x, text, cond) in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > steps:
- break
+ if hypernetwork.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except Exception:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
-
+ cond = cond.to(devices.device)
x = x.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
+ del cond
losses[hypernetwork.step % losses.shape[0]] = loss.item()
@@ -244,6 +269,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=preview_text,
@@ -255,6 +284,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
processed = processing.process_images(p)
image = processed.images[0]
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
shared.state.current_image = image
image.save(last_saved_image)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index e7540f41..dfa599af 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -5,15 +5,15 @@ import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
+from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name):
+def create_hypernetwork(name, enable_sizes):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
hypernet.save(fn)
shared.reload_hypernetworks()
@@ -25,6 +25,8 @@ def train_hypernetwork(*args):
initial_hypernetwork = shared.loaded_hypernetwork
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
try:
sd_hijack.undo_optimizations()
@@ -39,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)}
raise
finally:
shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()