aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/PULL_REQUEST_TEMPLATE/pull_request_template.md28
-rw-r--r--modules/devices.py6
-rw-r--r--modules/processing.py14
-rw-r--r--modules/safe.py6
-rw-r--r--modules/sd_models.py3
-rw-r--r--modules/sd_samplers.py4
-rw-r--r--modules/shared.py1
-rw-r--r--modules/textual_inversion/dataset.py3
-rw-r--r--modules/textual_inversion/preprocess.py19
-rw-r--r--modules/textual_inversion/textual_inversion.py14
-rw-r--r--modules/ui.py12
-rw-r--r--style.css4
12 files changed, 91 insertions, 23 deletions
diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
new file mode 100644
index 00000000..86009613
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
@@ -0,0 +1,28 @@
+# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
+
+If you have a large change, pay special attention to this paragraph:
+
+> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
+
+Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
+
+**Describe what this pull request is trying to achieve.**
+
+A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
+
+**Additional notes and description of your changes**
+
+More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
+
+**Environment this was tested in**
+
+List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
+ - OS: [e.g. Windows, Linux]
+ - Browser [e.g. chrome, safari]
+ - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+
+**Screenshots or videos of your changes**
+
+If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
+
+This is **required** for anything that touches the user interface. \ No newline at end of file
diff --git a/modules/devices.py b/modules/devices.py
index 0158b11f..03ef58f1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
+dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -59,9 +60,12 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
-def autocast():
+def autocast(disable=False):
from modules import shared
+ if disable:
+ return contextlib.nullcontext()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/processing.py b/modules/processing.py
index 94d2dd62..50ba4fc5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_first_stage(model, x):
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
+ x = model.decode_first_stage(x)
+
+ return x
+
+
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -398,9 +405,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
- samples_ddim = samples_ddim.to(devices.dtype)
-
- x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
@@ -533,7 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else:
- decoded_samples = self.sd_model.decode_first_stage(samples)
+ decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
diff --git a/modules/safe.py b/modules/safe.py
index 4d06f2a5..05917463 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -12,6 +12,10 @@ import _codecs
import zipfile
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
+
+
def encode(*args):
out = _codecs.encode(*args)
return out
@@ -20,7 +24,7 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
- return torch.storage._TypedStorage()
+ return TypedStorage()
def find_class(self, module, name):
if module == 'collections' and name == 'OrderedDict':
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e63d3c29..2cdcd84f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info):
model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if os.path.exists(vae_file):
@@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.load_state_dict(vae_dict)
+ model.first_stage_model.to(devices.dtype_vae)
+
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6e743f7e..d168b938 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples):
- x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
diff --git a/modules/shared.py b/modules/shared.py
index 1995a99a..5dfc344c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..bcf772d2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -15,11 +15,10 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
- self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..d7efdef2 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,8 +7,9 @@ import tqdm
from modules import shared, images
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
@@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ img = img.resize((width, height * img.height // img.width))
- top = img.crop((0, 0, size, size))
+ top = img.crop((0, 0, width, height))
save_pic(top, index)
- bot = img.crop((0, img.height - size, size, img.height))
+ bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ img = img.resize((width * img.width // img.height, height))
- left = img.crop((0, 0, size, size))
+ left = img.crop((0, 0, width, height))
save_pic(left, index)
- right = img.crop((img.width - size, 0, img.width, size))
+ right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
else:
- img = images.resize_image(1, img, size, size)
+ img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 9a18ee5c..7a24192e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -190,7 +190,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -222,7 +222,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -240,6 +240,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if ititial_step > steps:
return embedding, filename
+ tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+ epoch_len = (tr_img_len * num_repeats) + tr_img_len
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
@@ -263,7 +266,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ epoch_num = embedding.step // epoch_len
+ epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -276,6 +282,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
sd_model=shared.sd_model,
prompt=text,
steps=20,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
diff --git a/modules/ui.py b/modules/ui.py
index 202c4866..0f6427a6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1029,6 +1029,8 @@ def create_ui(wrap_gradio_gpu_call):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1043,13 +1045,16 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group():
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
+ num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
@@ -1093,6 +1098,8 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
process_src,
process_dst,
+ process_width,
+ process_height,
process_flip,
process_split,
process_caption,
@@ -1111,7 +1118,10 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate,
dataset_directory,
log_directory,
+ training_width,
+ training_height,
steps,
+ num_repeats,
create_image_every,
save_embedding_every,
template_file,
diff --git a/style.css b/style.css
index c0c3f2bb..04bb9576 100644
--- a/style.css
+++ b/style.css
@@ -1,3 +1,7 @@
+.container {
+ max-width: 100%;
+}
+
.output-html p {margin: 0 0.5em;}
.row > *,