From fddb4883f4a408b3464076465e1b0949ebe0fc30 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Wed, 26 Oct 2022 22:33:45 +0800
Subject: prototype progress api
---
modules/api/api.py | 89 +++++++++++++++++++++++++++++++++++++++++++++---------
modules/shared.py | 13 ++++++++
2 files changed, 88 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 6e9d6097..c038f674 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,8 +1,11 @@
+import time
+
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
+from modules import devices
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
parameters: Json
info: Json
+class ProgressResponse(BaseModel):
+ progress: float
+ eta_relative: float
+ state: Json
+
+# copy from wrap_gradio_gpu_call of webui.py
+# because queue lock will be acquired in api handlers
+# and time start needs to be set
+# the function has been modified into two parts
+
+def before_gpu_call():
+ devices.torch_gc()
+
+ shared.state.sampling_step = 0
+ shared.state.job_count = -1
+ shared.state.job_no = 0
+ shared.state.job_timestamp = shared.state.get_job_timestamp()
+ shared.state.current_latent = None
+ shared.state.current_image = None
+ shared.state.current_image_sampling_step = 0
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.textinfo = None
+ shared.state.time_start = time.time()
+
+
+def after_gpu_call():
+ shared.state.job = ""
+ shared.state.job_count = 0
+
+ devices.torch_gc()
class Api:
def __init__(self, app, queue_lock):
@@ -33,6 +67,7 @@ class Api:
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
@@ -44,12 +79,12 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -57,9 +92,11 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
+ before_gpu_call()
with self.queue_lock:
processed = process_images(p)
-
+ after_gpu_call()
+
b64images = []
for i in processed.images:
buffer = io.BytesIO()
@@ -67,30 +104,30 @@ class Api:
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
-
-
+
+
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
+ raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
- raise HTTPException(status_code=404, detail="Init image not found")
+ raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask
if mask:
mask = self.__base64_to_image(mask)
-
+
populate = img2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
- "do_not_save_grid": True,
+ "do_not_save_grid": True,
"mask": mask
}
)
@@ -103,9 +140,11 @@ class Api:
p.init_images = imgs
# Override object param
+ before_gpu_call()
with self.queue_lock:
processed = process_images(p)
-
+ after_gpu_call()
+
b64images = []
for i in processed.images:
buffer = io.BytesIO()
@@ -118,6 +157,28 @@ class Api:
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
+ def progressapi(self):
+ # copy from check_progress_call of ui.py
+
+ if shared.state.job_count == 0:
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
+
+ # avoid dividing zero
+ progress = 0.01
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ eta_relative = eta-time_since_start
+
+ progress = min(progress, 1)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
+
def extrasapi(self):
raise NotImplementedError
diff --git a/modules/shared.py b/modules/shared.py
index 1a9b8289..00f61898 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -146,6 +146,19 @@ class State:
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
+ def js(self):
+ obj = {
+ "skipped": self.skipped,
+ "interrupted": self.skipped,
+ "job": self.job,
+ "job_count": self.job_count,
+ "job_no": self.job_no,
+ "sampling_step": self.sampling_step,
+ "sampling_steps": self.sampling_steps,
+ }
+
+ return json.dumps(obj)
+
state = State()
--
cgit v1.2.1
From 0dd8480281ffa3e58439a3ce059c02d9f3baa5c7 Mon Sep 17 00:00:00 2001
From: MMaker
Date: Wed, 26 Oct 2022 11:08:44 -0400
Subject: fix: Correct before image saved callback
---
modules/script_callbacks.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..cedbe7bd 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -69,7 +69,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in callbacks_before_image_saved:
try:
c.callback(params)
except Exception:
--
cgit v1.2.1
From 9e465c8aa5616df4c6723bee007ffd3910404f12 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 27 Oct 2022 23:03:34 -0700
Subject: Add strength to textinfo.
---
modules/processing.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 4efba946..93066522 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -329,6 +329,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.1
From d4a069a23cb19104b4e58a33d0d1670fadaefb7a Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 27 Oct 2022 23:16:27 -0700
Subject: Read hypernet strength from PNG info.
---
modules/ui.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 0a63e357..62a2f4f3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1812,6 +1812,7 @@ Requested path was: {f}
settings_map = {
'sd_hypernetwork': 'Hypernet',
+ 'sd_hypernetwork_strength': 'Hypernetwork strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
--
cgit v1.2.1
From c0677b33161f04c3ed1a7a78f4c7288fb95787b7 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 27 Oct 2022 23:31:45 -0700
Subject: Explicitly state when Hypernet is none.
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 93066522..74a0cd64 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -328,7 +328,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet": ("None" if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
--
cgit v1.2.1
From db5a354c489bfd1c95e0bbf9af12ab8b5d6fe170 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 28 Oct 2022 01:41:57 -0700
Subject: Always ignore "None.pt" in the hypernet directory.
---
modules/hypernetworks/hypernetwork.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8113b35b..cd920df5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -208,13 +208,16 @@ def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
+ # Prevent a hypothetical "None.pt" from being listed.
+ if name != "None":
+ res[name] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
- if path is not None:
+ # Prevent any file named "None.pt" from being loaded.
+ if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
--
cgit v1.2.1
From 2c4d20388425a5e40b93eef3722e42e8d375fbb4 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Sat, 29 Oct 2022 00:36:51 -0700
Subject: Revert "Explicitly state when Hypernet is none."
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 377c0978..04fdda7c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -395,7 +395,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": ("None" if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
--
cgit v1.2.1
From a5f3adbdd7d9b8245f7782216ac48913660e6bb5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 15:37:24 +0700
Subject: Allow trailing comma in learning rate
---
modules/textual_inversion/learn_schedule.py | 33 +++++++++++++++++------------
1 file changed, 20 insertions(+), 13 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 3a736065..76e611b6 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -11,23 +11,30 @@ class LearnScheduleIterator:
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception("Invalid learning rate schedule")
+
def __iter__(self):
return self
--
cgit v1.2.1
From ef4c94e1cfe66299227aa95a28c2380d21cb1600 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 15:42:51 +0700
Subject: Improve lr schedule error message
---
modules/textual_inversion/learn_schedule.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 76e611b6..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,7 +4,7 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
"""
pairs = learn_rate.split(',')
@@ -33,7 +33,7 @@ class LearnScheduleIterator:
return
assert self.rates
except (ValueError, AssertionError):
- raise Exception("Invalid learning rate schedule")
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
def __iter__(self):
--
cgit v1.2.1
From ab27c111d06ec920791c73eea25ad9a61671852e Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 18:09:17 +0700
Subject: Add input validations before loading dataset for training
---
modules/hypernetworks/hypernetwork.py | 38 +++++++++++---------
modules/textual_inversion/textual_inversion.py | 48 +++++++++++++++++++-------
2 files changed, 58 insertions(+), 28 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2e84583b..38f35c58 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
- assert hypernetwork_name, 'hypernetwork not selected'
+ save_hypernetwork_every = save_hypernetwork_every or 0
+ create_image_every = create_image_every or 0
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
images_dir = None
+ hypernetwork = shared.loaded_hypernetwork
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
-
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
-
- last_saved_file = ""
- last_saved_image = ""
- forced_filename = ""
-
- ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
- return hypernetwork, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
+ last_saved_file = ""
+ last_saved_image = ""
+ forced_filename = ""
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb223..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_file, "Prompt template file is empty"
+ assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0 , "Max steps must be positive"
+ assert isinstance(save_model_every, int), "Save {name} must be integer"
+ assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0 , "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert embedding_name, 'embedding not selected'
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
-
+
cond_model = shared.sd_model.cond_stage_model
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return embedding, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
- hijack = sd_hijack.model_hijack
-
- embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
losses = torch.zeros((32,))
@@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = ""
embedding_yet_to_be_embedded = False
- ititial_step = embedding.step or 0
- if ititial_step > steps:
- return embedding, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
--
cgit v1.2.1
From 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 19:43:21 +0700
Subject: Add cleanup after training
---
modules/hypernetworks/hypernetwork.py | 201 +++++++++++++------------
modules/textual_inversion/textual_inversion.py | 185 ++++++++++++-----------
2 files changed, 200 insertions(+), 186 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 38f35c58..170d5ea4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -398,110 +398,112 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
forced_filename = ""
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, entries in pbar:
- hypernetwork.step = i + ititial_step
- if len(loss_dict) > 0:
- previous_mean_losses = [i[-1] for i in loss_dict.values()]
- previous_mean_loss = mean(previous_mean_losses)
-
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = stack_conds([entry.cond for entry in entries]).to(devices.device)
- # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
- del c
-
- losses[hypernetwork.step % losses.shape[0]] = loss.item()
- for entry in entries:
- loss_dict[entry.filename].append(loss.item())
-
- optimizer.zero_grad()
- weights[0].grad = None
- loss.backward()
- if weights[0].grad is None:
- steps_without_grad += 1
+ try:
+ for i, entries in pbar:
+ hypernetwork.step = i + ititial_step
+ if len(loss_dict) > 0:
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
+
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+ del c
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+ for entry in entries:
+ loss_dict[entry.filename].append(loss.item())
+
+ optimizer.zero_grad()
+ weights[0].grad = None
+ loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
+ optimizer.step()
+
+ steps_done = hypernetwork.step + 1
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ raise RuntimeError("Loss diverged.")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
else:
- steps_without_grad = 0
- assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
- optimizer.step()
-
- steps_done = hypernetwork.step + 1
-
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
- raise RuntimeError("Loss diverged.")
-
- if len(previous_mean_losses) > 1:
- std = stdev(previous_mean_losses)
- else:
- std = 0
- dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
- pbar.set_description(dataset_loss_info)
-
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
- # Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{previous_mean_loss:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
-
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
+
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+ hypernetwork.save(last_saved_file)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{previous_mean_loss:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
- if image is not None:
- shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = hypernetwork.step
+ shared.state.job_no = hypernetwork.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -510,7 +512,14 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
-
+ finally:
+ if weights:
+ for weight in weights:
+ weight.requires_grad = False
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44f06443..fd7f0897 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entries in pbar:
- embedding.step = i + ititial_step
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step % len(ds)
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
+ try:
+ for i, entries in pbar:
+ embedding.step = i + ititial_step
+
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step % len(ds)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+ embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0]
+ processed = processing.process_images(p)
+ image = processed.images[0]
- shared.state.current_image = image
+ shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = "<{}>".format(data.get('name', '???'))
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = embedding.step
+ shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
+ finally:
+ if embedding and embedding.vec is not None:
+ embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint()
--
cgit v1.2.1
From a27d19de2eff633b6a39f9f4a5c0f2d6abb81bb5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 19:44:05 +0700
Subject: Additional assert on dataset
---
modules/textual_inversion/dataset.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 8bb00d27..ad726577 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
--
cgit v1.2.1
From de1dc0d279a877d5d9f512befe30a7d7e5cf3881 Mon Sep 17 00:00:00 2001
From: Martin Cairns <4314538+MartinCairnsSQL@users.noreply.github.com>
Date: Sat, 29 Oct 2022 15:23:19 +0100
Subject: Add adjust_steps_if_invalid to find next valid step for ddim uniform
sampler
---
modules/sd_samplers.py | 28 +++++++++++++++-------------
1 file changed, 15 insertions(+), 13 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 3670b57d..aca014e8 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,5 +1,6 @@
from collections import namedtuple
import numpy as np
+from math import floor
import torch
import tqdm
from PIL import Image
@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if self.config.name == 'DDIM' and p.ddim_discretize == 'uniform':
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
-
+ steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
- # existing code fails with certain step counts, like 9
- try:
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler:
self.last_latent = x
self.step = 0
- steps = steps or p.steps
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- # existing code fails with certain step counts, like 9
- try:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
- except Exception:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
--
cgit v1.2.1
From ab05a74ead9fabb45dd099990e34061c7eb02ca3 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:32:02 +0700
Subject: Revert "Add cleanup after training"
This reverts commit 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1.
---
modules/hypernetworks/hypernetwork.py | 201 ++++++++++++-------------
modules/textual_inversion/textual_inversion.py | 185 +++++++++++------------
2 files changed, 186 insertions(+), 200 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 170d5ea4..38f35c58 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -398,112 +398,110 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
forced_filename = ""
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-
- try:
- for i, entries in pbar:
- hypernetwork.step = i + ititial_step
- if len(loss_dict) > 0:
- previous_mean_losses = [i[-1] for i in loss_dict.values()]
- previous_mean_loss = mean(previous_mean_losses)
-
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = stack_conds([entry.cond for entry in entries]).to(devices.device)
- # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
- del c
-
- losses[hypernetwork.step % losses.shape[0]] = loss.item()
- for entry in entries:
- loss_dict[entry.filename].append(loss.item())
-
- optimizer.zero_grad()
- weights[0].grad = None
- loss.backward()
-
- if weights[0].grad is None:
- steps_without_grad += 1
- else:
- steps_without_grad = 0
- assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
- optimizer.step()
-
- steps_done = hypernetwork.step + 1
-
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
- raise RuntimeError("Loss diverged.")
+ for i, entries in pbar:
+ hypernetwork.step = i + ititial_step
+ if len(loss_dict) > 0:
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
- if len(previous_mean_losses) > 1:
- std = stdev(previous_mean_losses)
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+ del c
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+ for entry in entries:
+ loss_dict[entry.filename].append(loss.item())
+
+ optimizer.zero_grad()
+ weights[0].grad = None
+ loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
else:
- std = 0
- dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
- pbar.set_description(dataset_loss_info)
-
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
- # Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{previous_mean_loss:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
-
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
+ optimizer.step()
- preview_text = p.prompt
+ steps_done = hypernetwork.step + 1
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ raise RuntimeError("Loss diverged.")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
+ else:
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
+
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+ hypernetwork.save(last_saved_file)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{previous_mean_loss:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
- if image is not None:
- shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.job_no = hypernetwork.step
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.textinfo = f"""
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -512,14 +510,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
- finally:
- if weights:
- for weight in weights:
- weight.requires_grad = False
- if unload:
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
+
report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd7f0897..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -283,113 +283,111 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ for i, entries in pbar:
+ embedding.step = i + ititial_step
- try:
- for i, entries in pbar:
- embedding.step = i + ititial_step
-
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step % len(ds)
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step % len(ds)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+ embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0]
+ processed = processing.process_images(p)
+ image = processed.images[0]
- shared.state.current_image = image
+ shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = "<{}>".format(data.get('name', '???'))
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = embedding.step
+ shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
- finally:
- if embedding and embedding.vec is not None:
- embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint()
--
cgit v1.2.1
From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:49:29 +0700
Subject: Add missing info on hypernetwork/embedding model log
Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513
Also group the saving into one
---
modules/hypernetworks/hypernetwork.py | 31 +++++++++++++-------
modules/textual_inversion/textual_inversion.py | 39 +++++++++++++++++---------
2 files changed, 47 insertions(+), 23 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 38f35c58..86daf825 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
images_dir = None
hypernetwork = shared.loaded_hypernetwork
+ checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
@@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
@@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}
"""
report_statistics(loss_dict)
- checkpoint = sd_models.select_checkpoint()
- hypernetwork.sd_checkpoint = checkpoint.hash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
- hypernetwork.name = hypernetwork_name
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(filename)
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
return hypernetwork, filename
+
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+ old_hypernetwork_name = hypernetwork.name
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+ try:
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.name = hypernetwork_name
+ hypernetwork.save(filename)
+ except:
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+ hypernetwork.name = old_hypernetwork_name
+ raise
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44f06443..ee9917ce 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -119,7 +119,7 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
@@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step > steps:
@@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
@@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}
"""
- checkpoint = sd_models.select_checkpoint()
-
- embedding.sd_checkpoint = checkpoint.hash
- embedding.sd_checkpoint_name = checkpoint.model_name
- embedding.cached_checksum = None
- # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
- embedding.name = embedding_name
- filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
- embedding.save(filename)
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
return embedding, filename
+
+def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
--
cgit v1.2.1
From 3d58510f214c645ce5cdb261aa47df6573b239e9 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:54:59 +0700
Subject: Fix dataset still being loaded even when training will be skipped
---
modules/hypernetworks/hypernetwork.py | 2 +-
modules/textual_inversion/textual_inversion.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 86daf825..07acadc9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -364,7 +364,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
+ if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ee9917ce..e0babb46 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -262,7 +262,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
- if ititial_step > steps:
+ if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
--
cgit v1.2.1
From 4609b83cd496013a05e77c42af031d89f07785a9 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 29 Oct 2022 16:09:19 -0300
Subject: Add PNG Info endpoint
---
modules/api/api.py | 12 +++++++++---
modules/api/models.py | 9 ++++++++-
2 files changed, 17 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 49c213ea..8fcd068d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -5,7 +5,7 @@ import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
-from modules.extras import run_extras
+from modules.extras import run_extras, run_pnginfo
def upscaler_to_index(name: str):
try:
@@ -32,6 +32,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -125,8 +126,13 @@ class Api:
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self):
- raise NotImplementedError
+ def pnginfoapi(self, req:PNGInfoRequest):
+ if(not req.image.strip()):
+ return PNGInfoResponse(info="")
+
+ result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+
+ return PNGInfoResponse(info=result[1])
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index dd122321..58e8e58b 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
import inspect
+from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
@@ -148,4 +149,10 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+ info: str = Field(title="Image info", description="A string with all the info the image had")
\ No newline at end of file
--
cgit v1.2.1
From 83a1f44ae26cb89492064bb8be0321b14a75efe4 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 29 Oct 2022 16:10:00 -0300
Subject: Fix space
---
modules/api/api.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 8fcd068d..d0f488ca 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -126,7 +126,7 @@ class Api:
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self, req:PNGInfoRequest):
+ def pnginfoapi(self, req: PNGInfoRequest):
if(not req.image.strip()):
return PNGInfoResponse(info="")
--
cgit v1.2.1
From 9bb6b6509aff8c1e6546d5a798ef9e9922758dc4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 22:20:02 +0300
Subject: add postprocess call for scripts
---
modules/processing.py | 12 +++++++++---
modules/scripts.py | 24 +++++++++++++++++++++---
2 files changed, 30 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 548eec29..50343846 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
diff --git a/modules/scripts.py b/modules/scripts.py
index a7f36012..96e44bfd 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -64,7 +64,16 @@ class Script:
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
- scripts. You can modify the processing object (p) here, inject hooks, etc.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
"""
pass
@@ -289,13 +298,22 @@ class ScriptRunner:
return processed
- def run_alwayson_scripts(self, p):
+ def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(f"Error running process: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess(self, p, processed):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess(p, processed, *script_args)
+ except Exception:
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
--
cgit v1.2.1
From f62db4d5c753bc32d2ae166606ce41f4c5fa5c43 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 03:55:43 +0800
Subject: fix progress response model
---
modules/api/api.py | 30 ------------------------------
modules/api/models.py | 8 ++++----
2 files changed, 4 insertions(+), 34 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index e93cddcb..7e8522a2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,33 +1,3 @@
-# import time
-
-# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
-# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-# from modules.sd_samplers import all_samplers
-# from modules.extras import run_pnginfo
-# import modules.shared as shared
-# from modules import devices
-# import uvicorn
-# from fastapi import Body, APIRouter, HTTPException
-# from fastapi.responses import JSONResponse
-# from pydantic import BaseModel, Field, Json
-# from typing import List
-# import json
-# import io
-# import base64
-# from PIL import Image
-
-# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-
-# class TextToImageResponse(BaseModel):
-# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
-# parameters: Json
-# info: Json
-
-# class ImageToImageResponse(BaseModel):
-# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
-# parameters: Json
-# info: Json
-
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
diff --git a/modules/api/models.py b/modules/api/models.py
index 8d4abc39..e1762fb9 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,6 +1,6 @@
import inspect
from click import prompt
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, Field, Json, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
@@ -158,6 +158,6 @@ class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had")
class ProgressResponse(BaseModel):
- progress: float
- eta_relative: float
- state: dict
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+ eta_relative: float = Field(title="ETA in secs")
+ state: Json
--
cgit v1.2.1
From e9c6c2a51f972fd7cd88ea740ade4ac3d8108b67 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 04:02:56 +0800
Subject: add description for state field
---
modules/api/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index e1762fb9..709ab5a6 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
- state: Json
+ state: Json = Field(title="State", description="The current state snapshot")
--
cgit v1.2.1
From 88f46a5bec610cf03641f18becbe3deda541e982 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 05:04:29 +0800
Subject: update progress response model
---
modules/api/api.py | 6 +++---
modules/api/models.py | 4 ++--
modules/shared.py | 4 ++--
3 files changed, 7 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 7e8522a2..5912d289 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -61,7 +61,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
- self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
+ self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -171,7 +171,7 @@ class Api:
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
# avoid dividing zero
progress = 0.01
@@ -187,7 +187,7 @@ class Api:
progress = min(progress, 1)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 709ab5a6..0ab85ec5 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,6 +1,6 @@
import inspect
from click import prompt
-from pydantic import BaseModel, Field, Json, create_model
+from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
@@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
- state: Json = Field(title="State", description="The current state snapshot")
+ state: dict = Field(title="State", description="The current state snapshot")
diff --git a/modules/shared.py b/modules/shared.py
index 0f4c035d..f7b0990c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -147,7 +147,7 @@ class State:
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
- def js(self):
+ def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.skipped,
@@ -158,7 +158,7 @@ class State:
"sampling_steps": self.sampling_steps,
}
- return json.dumps(obj)
+ return obj
state = State()
--
cgit v1.2.1
From 9f104b53c425e248595e5b6481336d2a339e015e Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 05:19:17 +0800
Subject: preview current image when opts.show_progress_every_n_steps is
enabled
---
modules/api/api.py | 8 ++++++--
modules/api/models.py | 1 +
2 files changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5912d289..e960bb7b 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,7 +1,7 @@
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
@@ -187,7 +187,11 @@ class Api:
progress = min(progress, 1)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
+ current_image = None
+ if shared.state.current_image:
+ current_image = encode_pil_to_base64(shared.state.current_image)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 0ab85ec5..c8bc719a 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -161,3 +161,4 @@ class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
--
cgit v1.2.1
From 66d038f6a41507af2243ff1f6618a745a092c290 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Sat, 29 Oct 2022 15:00:08 -0700
Subject: Read hypernet strength from PNG info.
---
modules/generation_parameters_copypaste.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index bbaad42e..59c6d7da 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = {
'sd_hypernetwork': 'Hypernet',
+ 'sd_hypernetwork_strength': 'Hypernetwork strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
--
cgit v1.2.1
From 9f4f894d74b57c3d02ebccaa59f9c22fca2b6c90 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 06:03:32 +0800
Subject: allow skip current image in progress api
---
modules/api/api.py | 4 ++--
modules/api/models.py | 3 +++
2 files changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index e960bb7b..5c5b210f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -167,7 +167,7 @@ class Api:
return PNGInfoResponse(info=result[1])
- def progressapi(self):
+ def progressapi(self, req: ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
@@ -188,7 +188,7 @@ class Api:
progress = min(progress, 1)
current_image = None
- if shared.state.current_image:
+ if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
diff --git a/modules/api/models.py b/modules/api/models.py
index c8bc719a..9ee42a17 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -157,6 +157,9 @@ class PNGInfoRequest(BaseModel):
class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had")
+class ProgressRequest(BaseModel):
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
--
cgit v1.2.1
From 05a657dd357eaca6940c4775daa946bd33f1167d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 30 Oct 2022 07:36:56 +0300
Subject: fix broken hires fix
---
modules/processing.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 50343846..947ce6fa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -686,15 +686,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ image_conditioning = self.txt2img_image_conditioning(x)
+
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- image_conditioning = self.img2img_image_conditioning(
- decoded_samples,
- samples,
- decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
- )
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
--
cgit v1.2.1
From 61836bd544fc8f4ef62f311c9d5964fbdaeb3f4c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 30 Oct 2022 08:48:53 +0300
Subject: shorten Hypernetwork strength in infotext and omit it when it's the
default value.
---
modules/generation_parameters_copypaste.py | 2 +-
modules/processing.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 59c6d7da..df70c728 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -66,7 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = {
'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernetwork strength',
+ 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
diff --git a/modules/processing.py b/modules/processing.py
index ecaa78e2..b1df4918 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -396,7 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
+ "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.1
From 149784202cca8612b43629c601ee27cfda64e623 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 30 Oct 2022 09:10:22 +0300
Subject: rework #3722 to not introduce duplicate code
---
modules/api/api.py | 43 +++++++++++++------------------------------
modules/shared.py | 22 +++++++++++++++++++---
2 files changed, 32 insertions(+), 33 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5c5b210f..6c06d449 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-# copy from wrap_gradio_gpu_call of webui.py
-# because queue lock will be acquired in api handlers
-# and time start needs to be set
-# the function has been modified into two parts
-
-def before_gpu_call():
- devices.torch_gc()
-
- shared.state.sampling_step = 0
- shared.state.job_count = -1
- shared.state.job_no = 0
- shared.state.job_timestamp = shared.state.get_job_timestamp()
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.current_image_sampling_step = 0
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.textinfo = None
- shared.state.time_start = time.time()
-
-def after_gpu_call():
- shared.state.job = ""
- shared.state.job_count = 0
-
- devices.torch_gc()
def upscaler_to_index(name: str):
try:
@@ -41,8 +16,10 @@ def upscaler_to_index(name: str):
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -78,10 +56,13 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -119,11 +100,13 @@ class Api:
imgs = [img] * p.batch_size
p.init_images = imgs
- # Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/shared.py b/modules/shared.py
index f7b0990c..e4f163c1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -144,9 +144,6 @@ class State:
self.sampling_step = 0
self.current_image_sampling_step = 0
- def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
-
def dict(self):
obj = {
"skipped": self.skipped,
@@ -160,6 +157,25 @@ class State:
return obj
+ def begin(self):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+
+ devices.torch_gc()
+
+ def end(self):
+ self.job = ""
+ self.job_count = 0
+
+ devices.torch_gc()
state = State()
--
cgit v1.2.1
From 34c86c12b0a9d650d4e7c5be478bca34ad8ed048 Mon Sep 17 00:00:00 2001
From: Martin Cairns <4314538+MartinCairnsSQL@users.noreply.github.com>
Date: Sun, 30 Oct 2022 11:04:27 +0000
Subject: Include PLMS in adjust steps as it also can fail in the same way
---
modules/sd_samplers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index aca014e8..8772db56 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -208,7 +208,7 @@ class VanillaStableDiffusionSampler:
def adjust_steps_if_invalid(self, p, num_steps):
- if self.config.name == 'DDIM' and p.ddim_discretize == 'uniform':
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
--
cgit v1.2.1
From 423f22228306ae72d0480e25add9777c3c5d8fdf Mon Sep 17 00:00:00 2001
From: Maiko Tan
Date: Sun, 30 Oct 2022 22:46:43 +0800
Subject: feat: add app started callback
---
modules/script_callbacks.py | 15 +++++++++++++++
1 file changed, 15 insertions(+)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..f5509629 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -3,6 +3,8 @@ import traceback
from collections import namedtuple
import inspect
+from fastapi import FastAPI
+from gradio import Blocks
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@@ -25,6 +27,7 @@ class ImageSaveParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -40,6 +43,14 @@ def clear_callbacks():
callbacks_image_saved.clear()
+def app_started_callback(demo: Blocks, app: FastAPI):
+ for c in callbacks_app_started:
+ try:
+ c.callback(demo, app)
+ except Exception:
+ report_exception(c, 'app_started_callback')
+
+
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
try:
@@ -91,6 +102,10 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+def on_app_started(callback):
+ add_callback(callbacks_app_started, callback)
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
--
cgit v1.2.1
From 081df45da47feadfb055552c0ad5c6e6ecdd9f28 Mon Sep 17 00:00:00 2001
From: Maiko Sinkyaet Tan
Date: Mon, 31 Oct 2022 08:47:43 +0800
Subject: docs: add python doc (?)
not sure if this available...
---
modules/script_callbacks.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index f5509629..d8aa6f00 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -103,6 +103,8 @@ def add_callback(callbacks, fun):
def on_app_started(callback):
+ """register a function to be called when the webui started, the gradio `Block` component and
+ fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
--
cgit v1.2.1
From 910a097ae2ed78a62101951f1b87137f9e1baaea Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 17:36:45 +0300
Subject: add initial version of the extensions tab fix broken Restart Gradio
button
---
modules/extensions.py | 83 +++++++++++++++
modules/generation_parameters_copypaste.py | 5 +
modules/scripts.py | 21 +---
modules/shared.py | 10 +-
modules/ui.py | 16 ++-
modules/ui_extensions.py | 162 +++++++++++++++++++++++++++++
6 files changed, 274 insertions(+), 23 deletions(-)
create mode 100644 modules/extensions.py
create mode 100644 modules/ui_extensions.py
(limited to 'modules')
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 00000000..8d6ae848
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,83 @@
+import os
+import sys
+import traceback
+
+import git
+
+from modules import paths, shared
+
+
+extensions = []
+extensions_dir = os.path.join(paths.script_path, "extensions")
+
+
+def active():
+ return [x for x in extensions if x.enabled]
+
+
+class Extension:
+ def __init__(self, name, path, enabled=True):
+ self.name = name
+ self.path = path
+ self.enabled = enabled
+ self.status = ''
+ self.can_update = False
+
+ repo = None
+ try:
+ if os.path.exists(os.path.join(path, ".git")):
+ repo = git.Repo(path)
+ except Exception:
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if repo is None or repo.bare:
+ self.remote = None
+ else:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+
+ def list_files(self, subdir, extension):
+ from modules import scripts
+
+ dirpath = os.path.join(self.path, subdir)
+ if not os.path.isdir(dirpath):
+ return []
+
+ res = []
+ for filename in sorted(os.listdir(dirpath)):
+ res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
+
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return res
+
+ def check_updates(self):
+ repo = git.Repo(self.path)
+ for fetch in repo.remote().fetch("--dry-run"):
+ if fetch.flags != fetch.HEAD_UPTODATE:
+ self.can_update = True
+ self.status = "behind"
+ return
+
+ self.can_update = False
+ self.status = "latest"
+
+ def pull(self):
+ repo = git.Repo(self.path)
+ repo.remotes.origin.pull()
+
+
+def list_extensions():
+ extensions.clear()
+
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname in sorted(os.listdir(extensions_dir)):
+ path = os.path.join(extensions_dir, dirname)
+ if not os.path.isdir(path):
+ continue
+
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
+ extensions.append(extension)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index df70c728..985ec95e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -17,6 +17,11 @@ paste_fields = {}
bind_list = []
+def reset():
+ paste_fields.clear()
+ bind_list.clear()
+
+
def quote(text):
if ',' not in str(text):
return text
diff --git a/modules/scripts.py b/modules/scripts.py
index 96e44bfd..533db45c 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks
+from modules import shared, paths, script_callbacks, extensions
AlwaysVisible = object()
@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- extdir = os.path.join(paths.script_path, "extensions")
- if os.path.exists(extdir):
- for dirname in sorted(os.listdir(extdir)):
- dirpath = os.path.join(extdir, dirname)
- scriptdirpath = os.path.join(dirpath, scriptdirname)
-
- if not os.path.isdir(scriptdirpath):
- continue
-
- for filename in sorted(os.listdir(scriptdirpath)):
- scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
+ for ext in extensions.active():
+ scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename):
res = []
- dirs = [paths.script_path]
-
- extdir = os.path.join(paths.script_path, "extensions")
- if os.path.exists(extdir):
- dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
for dirpath in dirs:
if not os.path.isdir(dirpath):
diff --git a/modules/shared.py b/modules/shared.py
index e4f163c1..cce87081 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -132,6 +132,7 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
+ need_restart = False
def skip(self):
self.skipped = True
@@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
+options_templates.update(options_section((None, "Hidden options"), {
+ "disabled_extensions": OptionInfo([], "Disable those extensions"),
+}))
+
+options_templates.update()
+
class Options:
data = None
@@ -365,8 +372,9 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
- if key in self.data:
+ if key in self.data or key in self.data_labels:
self.data[key] = value
+ return
return super(Options, self).__setattr__(key, value)
diff --git a/modules/ui.py b/modules/ui.py
index 5055ca64..2c15abb7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
-from modules import sd_hijack, sd_models, localization, script_callbacks
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
+ parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None
with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
- if previous_section != item.section:
+ if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None:
column.__exit__()
@@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
else:
component = create_setting_component(k)
component_dict[k] = component
@@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call):
def request_restart():
shared.state.interrupt()
- settings_interface.gradio_ref.do_restart = True
+ shared.state.need_restart = True
restart_gradio.click(
+
fn=request_restart,
inputs=[],
outputs=[],
@@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
- settings_interface.gradio_ref = demo
-
parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind()
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
new file mode 100644
index 00000000..b7d747dc
--- /dev/null
+++ b/modules/ui_extensions.py
@@ -0,0 +1,162 @@
+import json
+import os.path
+import shutil
+import sys
+import time
+import traceback
+
+import git
+
+import gradio as gr
+import html
+
+from modules import extensions, shared, paths
+
+
+def apply_and_restart(disable_list, update_list):
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ update = json.loads(update_list)
+ assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+
+ update = set(update)
+
+ for ext in extensions.extensions:
+ if ext.name not in update:
+ continue
+
+ try:
+ ext.pull()
+ except Exception:
+ print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+
+def check_updates():
+ for ext in extensions.extensions:
+ if ext.remote is None:
+ continue
+
+ try:
+ ext.check_updates()
+ except Exception:
+ print(f"Error checking updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return extension_table()
+
+
+def extension_table():
+ code = f"""
+
+ """
+
+ return code
+
+
+def install_extension_from_url(dirname, url):
+ assert url, 'No URL specified'
+
+ if dirname is None or dirname == "":
+ *parts, last_part = url.split('/')
+ last_part = last_part.replace(".git", "")
+
+ dirname = last_part
+
+ target_dir = os.path.join(extensions.extensions_dir, dirname)
+ assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
+
+ assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
+
+ tmpdir = os.path.join(paths.script_path, "tmp", dirname)
+
+ try:
+ shutil.rmtree(tmpdir, True)
+
+ repo = git.Repo.clone_from(url, tmpdir)
+ repo.remote().fetch()
+
+ os.rename(tmpdir, target_dir)
+
+ extensions.list_extensions()
+ return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
+ finally:
+ shutil.rmtree(tmpdir, True)
+
+
+def create_ui():
+ import modules.ui
+
+ with gr.Blocks(analytics_enabled=False) as ui:
+ with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.TabItem("Installed"):
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
+
+ with gr.Row():
+ apply = gr.Button(value="Apply and restart UI", variant="primary")
+ check = gr.Button(value="Check for updates")
+
+ extensions_table = gr.HTML(lambda: extension_table())
+
+ apply.click(
+ fn=apply_and_restart,
+ _js="extensions_apply",
+ inputs=[extensions_disabled_list, extensions_update_list],
+ outputs=[],
+ )
+
+ check.click(
+ fn=check_updates,
+ _js="extensions_check",
+ inputs=[],
+ outputs=[extensions_table],
+ )
+
+ with gr.TabItem("Install from URL"):
+ install_url = gr.Text(label="URL for extension's git repository")
+ install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
+ intall_button = gr.Button(value="Install", variant="primary")
+ intall_result = gr.HTML(elem_id="extension_install_result")
+
+ intall_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ inputs=[install_dirname, install_url],
+ outputs=[extensions_table, intall_result],
+ )
+
+ return ui
--
cgit v1.2.1
From dc7425a56e7a014cbfa3b3d44ad2321e519fe378 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 18:33:44 +0300
Subject: disable access to extension stuff for non-local servers
---
modules/shared.py | 5 ++++-
modules/ui_extensions.py | 10 ++++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index cce87081..a27c654e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
@@ -97,6 +97,9 @@ restricted_opts = {
"outdir_save",
}
+if cmd_opts.share or cmd_opts.listen:
+ cmd_opts.disable_extension_access = True
+
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index b7d747dc..e74b7d68 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -13,7 +13,13 @@ import html
from modules import extensions, shared, paths
+def check_access():
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
+
+
def apply_and_restart(disable_list, update_list):
+ check_access()
+
disabled = json.loads(disable_list)
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
@@ -40,6 +46,8 @@ def apply_and_restart(disable_list, update_list):
def check_updates():
+ check_access()
+
for ext in extensions.extensions:
if ext.remote is None:
continue
@@ -89,6 +97,8 @@ def extension_table():
def install_extension_from_url(dirname, url):
+ check_access()
+
assert url, 'No URL specified'
if dirname is None or dirname == "":
--
cgit v1.2.1
From 58cc03edd0fe8c7e64297bcfe51111caaafabfd7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 18:40:47 +0300
Subject: fix scripts I broke with the extension tab changes
---
modules/extensions.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extensions.py b/modules/extensions.py
index 8d6ae848..897af96e 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -46,7 +46,7 @@ class Extension:
res = []
for filename in sorted(os.listdir(dirpath)):
- res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
--
cgit v1.2.1
From 9e22a357545c0395c81dd800c72fa18f350545ec Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 18:45:50 +0300
Subject: fix the error with extension tab not working because of the previous
commit
---
modules/shared.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index a27c654e..c83fb9f5 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -97,8 +97,7 @@ restricted_opts = {
"outdir_save",
}
-if cmd_opts.share or cmd_opts.listen:
- cmd_opts.disable_extension_access = True
+cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
--
cgit v1.2.1
From 525c1edf431d9c9efc1349be8657f0301299e3bc Mon Sep 17 00:00:00 2001
From: k_sugawara
Date: Tue, 1 Nov 2022 09:40:54 +0900
Subject: make save dir if save dir is not exists
---
modules/img2img.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index efda26e1..35c5df9b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}"
if not save_normally:
+ os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename))
--
cgit v1.2.1
From 5b0f624bdc1335313258f59a37607e699e800c22 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 1 Nov 2022 09:59:00 +0300
Subject: Added Available tab to extensions UI.
---
modules/ui_extensions.py | 112 +++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 104 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e74b7d68..ab807722 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -13,6 +13,9 @@ import html
from modules import extensions, shared, paths
+available_extensions = {"extensions": []}
+
+
def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
@@ -96,6 +99,14 @@ def extension_table():
return code
+def normalize_git_url(url):
+ if url is None:
+ return ""
+
+ url = url.replace(".git", "")
+ return url
+
+
def install_extension_from_url(dirname, url):
check_access()
@@ -103,14 +114,15 @@ def install_extension_from_url(dirname, url):
if dirname is None or dirname == "":
*parts, last_part = url.split('/')
- last_part = last_part.replace(".git", "")
+ last_part = normalize_git_url(last_part)
dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
- assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
+ normalized_url = normalize_git_url(url)
+ assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
@@ -128,18 +140,80 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
+def install_extension_from_index(url):
+ ext_table, message = install_extension_from_url(None, url)
+
+ return refresh_available_extensions_from_data(), ext_table, message
+
+
+def refresh_available_extensions(url):
+ global available_extensions
+
+ import urllib.request
+ with urllib.request.urlopen(url) as response:
+ text = response.read()
+
+ available_extensions = json.loads(text)
+
+ return url, refresh_available_extensions_from_data(), ''
+
+
+def refresh_available_extensions_from_data():
+ extlist = available_extensions["extensions"]
+ installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+
+ code = f"""
+
+ """
+
+ return code
+
+
def create_ui():
import modules.ui
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
- extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
- extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
with gr.Row():
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
extensions_table = gr.HTML(lambda: extension_table())
@@ -157,16 +231,38 @@ def create_ui():
outputs=[extensions_table],
)
+ with gr.TabItem("Available"):
+ with gr.Row():
+ refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
+ install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+
+ install_result = gr.HTML()
+ available_extensions_table = gr.HTML()
+
+ refresh_available_extensions_button.click(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[available_extensions_index],
+ outputs=[available_extensions_index, available_extensions_table, install_result],
+ )
+
+ install_extension_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[extension_to_install],
+ outputs=[available_extensions_table, extensions_table, install_result],
+ )
+
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
- intall_button = gr.Button(value="Install", variant="primary")
- intall_result = gr.HTML(elem_id="extension_install_result")
+ install_button = gr.Button(value="Install", variant="primary")
+ install_result = gr.HTML(elem_id="extension_install_result")
- intall_button.click(
+ install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url],
- outputs=[extensions_table, intall_result],
+ outputs=[extensions_table, install_result],
)
return ui
--
cgit v1.2.1
From af758e97fa2c4c853042f121af4e974be01e6696 Mon Sep 17 00:00:00 2001
From: Jairo Correa
Date: Tue, 1 Nov 2022 04:01:49 -0300
Subject: Unload sd_model before loading the other
---
modules/lowvram.py | 21 +++++++++++++--------
modules/processing.py | 3 +++
modules/sd_hijack.py | 4 ++++
modules/sd_models.py | 14 +++++++++++++-
4 files changed, 33 insertions(+), 9 deletions(-)
(limited to 'modules')
diff --git a/modules/lowvram.py b/modules/lowvram.py
index f327c3df..a4652cb1 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
diff --git a/modules/processing.py b/modules/processing.py
index b1df4918..57d3a523 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess(p, res)
+ p.sd_model = None
+ p.sampler = None
+
return res
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 0f10828e..bc49d235 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ self.layers = None
+ self.circular_enabled = False
+ self.clip = None
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f86dc3ed..90007da3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,6 +1,7 @@
import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
import re
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
+ if shared.sd_model:
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ gc.collect()
+ devices.torch_gc()
+
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
return sd_model
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ if not sd_model:
+ sd_model = shared.sd_model
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model
--
cgit v1.2.1
From d35bf649456da2558cbb6f2ea16fa1606022b7e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 1 Nov 2022 14:19:24 +0300
Subject: make launch.py run installers for extensions that have ones add some
more classes to safety module for an extension
---
modules/safe.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/safe.py b/modules/safe.py
index 399165a1..348a24fc 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
--
cgit v1.2.1