aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBilly Cao <aliencaocao@gmail.com>2022-11-23 18:11:24 +0800
committerBilly Cao <aliencaocao@gmail.com>2022-11-23 18:11:24 +0800
commitadb6cb7619989cbc7a271cc6c2ae27bb936c43d9 (patch)
tree164da7276d0dcb00d3f6871c9099604a05151277
parent828438b4a190759807f9054932cae3a8b880ddf1 (diff)
Patch UNet Forward to support resolutions that are not multiples of 64
Also modifed the UI to no longer step in 64
-rw-r--r--modules/sd_hijack.py2
-rw-r--r--modules/sd_hijack_optimizations.py31
-rw-r--r--modules/ui.py24
3 files changed, 45 insertions, 12 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index eaedac13..6141f705 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -16,6 +16,7 @@ import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
+import ldm.modules.diffusionmodules.openaimodel
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -26,6 +27,7 @@ def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..8cd4c954 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,7 @@ import importlib
import torch
from torch import einsum
+import torch.nn.functional as F
from ldm.util import default
from einops import rearrange
@@ -12,6 +13,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
+from ldm.modules.diffusionmodules.util import timestep_embedding
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+
+def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ if h.shape[-2:] != hs[-1].shape[-2:]:
+ h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
diff --git a/modules/ui.py b/modules/ui.py
index e6da1b2a..85e531af 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -380,8 +380,8 @@ def create_seed_inputs():
with gr.Row(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from width", value=0)
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from height", value=0)
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -715,8 +715,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
+ height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -724,8 +724,8 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
- firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
+ firstphase_width = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass width", value=0)
+ firstphase_height = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(equal_height=True):
@@ -901,8 +901,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -1231,8 +1231,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
- process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ process_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
@@ -1289,8 +1289,8 @@ def create_ui(wrap_gradio_gpu_call):
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
- training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ training_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)