 .github/ISSUE_TEMPLATE/bug_report.md      |  32
 .github/ISSUE_TEMPLATE/feature_request.md |  20
 README.md                                 |  11
 modules/codeformer_model.py               |  32
 modules/devices.py                        |  17
 modules/errors.py                         |  10
 modules/extras.py                         |   8
 modules/gfpgan_model.py                   |  15
 modules/images.py                         |  43
 modules/img2img.py                        |   7
 modules/interrogate.py                    |  63
 modules/lowvram.py                        |  10
 modules/processing.py                     |  24
 modules/sd_hijack.py                      |   7
 modules/sd_samplers.py                    |   8
 modules/shared.py                         |  33
 modules/ui.py                             |  27
 scripts/img2imgalt.py                     | 104
 scripts/poor_mans_outpainting.py          |   6
 scripts/prompt_matrix.py                  |   2
 scripts/prompts_from_file.py              |  42
 scripts/xy_grid.py                        |   2
 webui.bat                                 |   6
 webui.py                                  |  11
 24 files changed, 440 insertions(+), 100 deletions(-)
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 00000000..21accbf0
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,32 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: bug
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior:
+1. Go to '...'
+2. Click on '....'
+3. Scroll down to '....'
+4. See error
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**Desktop (please complete the following information):**
+ - OS: [e.g. Windows, Linux]
+ - Browser [e.g. chrome, safari]
+ - Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 00000000..bbcbbe7d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
diff --git a/README.md b/README.md
index 04de29ca..9bde1f2a 100644
--- a/README.md
+++ b/README.md
@@ -283,6 +283,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pt
After that follow the instructions in the `Manual instructions` section starting at step `:: clone repositories for Stable Diffusion and (optionally) CodeFormer`.
+### img2img alternative test
+- see [this post](https://www.reddit.com/r/StableDiffusion/comments/xboy90/a_better_way_of_doing_img2img_by_finding_the/) on reddit for context.
+- find it in the scripts section
+- put a description of the input image into the Original prompt field
+- use Euler only
+- recommended: 50 steps, low cfg scale between 1 and 2
+- denoising and seed don't matter
+- decode cfg scale between 0 and 1
+- decode steps 50
+- the original blue-haired woman nearly reproduces with cfg scale=1.8
+
## Credits
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index c638cb4d..6cd29c83 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -5,7 +5,7 @@ import traceback
import cv2
import torch
-from modules import shared
+from modules import shared, devices
from modules.paths import script_path
import modules.shared
import modules.face_restoration
@@ -53,6 +53,7 @@ def setup_codeformer():
def create_models(self):
if self.net is not None and self.face_helper is not None:
+ self.net.to(shared.device)
return self.net, self.face_helper
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device_codeformer)
@@ -63,9 +64,9 @@ def setup_codeformer():
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device_codeformer)
- if not cmd_opts.unload_gfpgan:
- self.net = net
- self.face_helper = face_helper
+ self.net = net
+ self.face_helper = face_helper
+ self.net.to(shared.device)
return net, face_helper
@@ -74,20 +75,20 @@ def setup_codeformer():
original_resolution = np_image.shape[0:2]
- net, face_helper = self.create_models()
- face_helper.clean_all()
- face_helper.read_image(np_image)
- face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- face_helper.align_warp_face()
+ self.create_models()
+ self.face_helper.clean_all()
+ self.face_helper.read_image(np_image)
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ self.face_helper.align_warp_face()
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device_codeformer)
try:
with torch.no_grad():
- output = net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
+ output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
@@ -96,16 +97,19 @@ def setup_codeformer():
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
- face_helper.add_restored_face(restored_face)
+ self.face_helper.add_restored_face(restored_face)
- face_helper.get_inverse_affine(None)
+ self.face_helper.get_inverse_affine(None)
- restored_img = face_helper.paste_faces_to_input_image()
+ restored_img = self.face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ if shared.opts.face_restoration_unload:
+ self.net.to(devices.cpu)
+
return restored_img
global have_codeformer
diff --git a/modules/devices.py b/modules/devices.py
index 30d30b99..a93a245b 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,6 +1,8 @@
import torch
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
+from modules import errors
+
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
@@ -14,3 +16,18 @@ def get_optimal_device():
return torch.device("mps")
return cpu
+
+
+def torch_gc():
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+def enable_tf32():
+ if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+
+errors.run(enable_tf32, "Enabling TF32")
diff --git a/modules/errors.py b/modules/errors.py
new file mode 100644
index 00000000..372dc51a
--- /dev/null
+++ b/modules/errors.py
@@ -0,0 +1,10 @@
+import sys
+import traceback
+
+
+def run(code, task):
+ try:
+ code()
+ except Exception as e:
+ print(f"{task}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
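The new errors.run helper is used right away in modules/devices.py to wrap enable_tf32. A minimal standalone sketch of the behaviour (broken_setup is an illustrative stand-in, not webui code): the wrapped call is attempted, a failure is reported to stderr with the task name, and execution continues instead of raising.

```python
# Minimal sketch of how the new errors.run helper behaves.
import sys
import traceback

def run(code, task):
    try:
        code()
    except Exception as e:
        print(f"{task}: {type(e).__name__}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

def broken_setup():
    raise RuntimeError("no GPU available")  # illustrative failure

run(broken_setup, "Enabling TF32")  # logs the error to stderr
print("still running")              # execution continues
```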
diff --git a/modules/extras.py b/modules/extras.py
index 6aeae6cb..596cd172 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,7 +1,7 @@
import numpy as np
from PIL import Image
-from modules import processing, shared, images
+from modules import processing, shared, images, devices
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -11,7 +11,9 @@ cached_images = {}
def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
- processing.torch_gc()
+ devices.torch_gc()
+
+ existing_pnginfo = image.info or {}
image = image.convert("RGB")
info = ""
@@ -65,7 +67,7 @@ def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weigh
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
- images.save_image(image, outpath, "", None, info=info, extension=opts.samples_format, short_filename=True, no_prompt=True, pnginfo_section_name="extras")
+ images.save_image(image, outpath, "", None, info=info, extension=opts.samples_format, short_filename=True, no_prompt=True, pnginfo_section_name="extras", existing_info=existing_pnginfo)
return image, plaintext_to_html(info), ''
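The run_extras change above captures image.info before converting the image and passes it to save_image as existing_info, so PNG text chunks from the source file (such as the generation parameters) survive the extras pass. A rough Pillow-only sketch of the same idea, assuming a hypothetical input.png; the string-type guard is an extra precaution, not part of the webui code.

```python
# Sketch: preserve existing PNG text chunks when re-saving an image.
from PIL import Image, PngImagePlugin

src = Image.open("input.png")       # hypothetical input file
existing_info = src.info or {}      # text chunks, e.g. "parameters"

out = src.convert("RGB")            # processing would happen here

pnginfo = PngImagePlugin.PngInfo()
for key, value in existing_info.items():
    if isinstance(value, str):      # only copy plain text chunks
        pnginfo.add_text(key, value)
pnginfo.add_text("extras", "postprocessing info")  # new section, like pnginfo_section_name

out.save("output.png", pnginfo=pnginfo)
```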
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index f697326c..0af97123 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -2,7 +2,7 @@ import os
import sys
import traceback
-from modules import shared
+from modules import shared, devices
from modules.shared import cmd_opts
from modules.paths import script_path
import modules.face_restoration
@@ -28,24 +28,29 @@ def gfpgan():
global loaded_gfpgan_model
if loaded_gfpgan_model is not None:
+ loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model
if gfpgan_constructor is None:
return None
model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
-
- if not cmd_opts.unload_gfpgan:
- loaded_gfpgan_model = model
+ model.gfpgan.to(shared.device)
+ loaded_gfpgan_model = model
return model
def gfpgan_fix_faces(np_image):
+ model = gfpgan()
+
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
+ if shared.opts.face_restoration_unload:
+ model.gfpgan.to(devices.cpu)
+
return np_image
diff --git a/modules/images.py b/modules/images.py
index 26c399b6..d742ed98 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -9,6 +9,7 @@ from fonts.ttf import Roboto
import string
import modules.shared
+from modules import sd_samplers
from modules.shared import opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -134,7 +135,12 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
fontsize = (width + height) // 25
line_spacing = fontsize // 2
- fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
+
+ try:
+ fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
+ except Exception:
+ fnt = ImageFont.truetype(Roboto, fontsize)
+
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
@@ -239,23 +245,46 @@ invalid_filename_chars = '<>:"/\\|?*\n'
re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
-def sanitize_filename_part(text):
- return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
+def sanitize_filename_part(text, replace_spaces=True):
+ if replace_spaces:
+ text = text.replace(' ', '_')
+
+ return text.translate({ord(x): '' for x in invalid_filename_chars})[:128]
-def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters'):
+def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters', p=None, existing_info=None):
# would be better to add this as an argument in future, but will do for now
is_a_grid = basename != ""
if short_filename or prompt is None or seed is None:
file_decoration = ""
elif opts.save_to_dirs:
- file_decoration = f"-{seed}"
+ file_decoration = opts.samples_filename_format or "[seed]"
else:
- file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}"
+ file_decoration = opts.samples_filename_format or "[seed]-[prompt_spaces]"
+
+ if file_decoration != "":
+ file_decoration = "-" + file_decoration.lower()
+
+ if seed is not None:
+ file_decoration = file_decoration.replace("[seed]", str(seed))
+ if prompt is not None:
+ file_decoration = file_decoration.replace("[prompt]", sanitize_filename_part(prompt)[:128])
+ file_decoration = file_decoration.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False)[:128])
+ if p is not None:
+ file_decoration = file_decoration.replace("[steps]", str(p.steps))
+ file_decoration = file_decoration.replace("[cfg]", str(p.cfg_scale))
+ file_decoration = file_decoration.replace("[width]", str(p.width))
+ file_decoration = file_decoration.replace("[height]", str(p.height))
+ file_decoration = file_decoration.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name)
if extension == 'png' and opts.enable_pnginfo and info is not None:
pnginfo = PngImagePlugin.PngInfo()
+
+ if existing_info is not None:
+ for k, v in existing_info.items():
+ pnginfo.add_text(k, v)
+
pnginfo.add_text(pnginfo_section_name, info)
else:
pnginfo = None
@@ -315,7 +344,7 @@ class Upscaler:
img = self.do_upscale(img)
if img.width != w or img.height != h:
- img = img.resize((w, h), resample=LANCZOS)
+ img = img.resize((int(w), int(h)), resample=LANCZOS)
return img
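save_image now builds the filename suffix from opts.samples_filename_format by substituting bracketed tags such as [seed], [prompt_spaces], [steps] and [cfg]. A rough standalone sketch of that substitution, with a simplified sanitize helper standing in for sanitize_filename_part (names and example values here are illustrative):

```python
# Rough sketch of the tag substitution save_image performs for
# opts.samples_filename_format.
invalid_filename_chars = '<>:"/\\|?*\n'

def sanitize(text, replace_spaces=True):
    if replace_spaces:
        text = text.replace(' ', '_')
    return text.translate({ord(c): '' for c in invalid_filename_chars})[:128]

def decorate(fmt, seed, prompt, steps, cfg):
    decoration = "-" + fmt.lower()
    decoration = decoration.replace("[seed]", str(seed))
    decoration = decoration.replace("[prompt]", sanitize(prompt))
    decoration = decoration.replace("[prompt_spaces]", sanitize(prompt, replace_spaces=False))
    decoration = decoration.replace("[steps]", str(steps))
    decoration = decoration.replace("[cfg]", str(cfg))
    return decoration

# "[seed]-[steps]-[cfg]" -> "-1234-50-7.5"
print(decorate("[seed]-[steps]-[cfg]", 1234, "blue haired woman", 50, 7.5))
```

With the default pattern `[seed]-[prompt_spaces]` this reproduces the previous `-{seed}-{prompt}` filenames.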
diff --git a/modules/img2img.py b/modules/img2img.py
index 779f620d..15e35093 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,6 +3,7 @@ import cv2
import numpy as np
from PIL import Image, ImageOps, ImageChops
+from modules import devices
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -118,7 +119,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
grid = images.image_grid(history, batch_size, rows=1)
- images.save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
+ images.save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, p=p)
processed = Processed(p, history, initial_seed, initial_info)
@@ -131,7 +132,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
upscaler = shared.sd_upscalers[upscaler_index]
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
- processing.torch_gc()
+ devices.torch_gc()
grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
@@ -179,7 +180,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
result_images.append(combined_image)
if opts.samples_save:
- images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.grid_format, info=initial_info)
+ images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.samples_format, info=initial_info, p=p)
processed = Processed(p, result_images, seed, initial_info)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index ed97a58b..06862fcc 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,3 +1,4 @@
+import contextlib
import os
import sys
import traceback
@@ -6,12 +7,11 @@ import re
import torch
-from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths
+from modules import devices, paths, lowvram
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
@@ -26,6 +26,7 @@ class InterrogateModels:
clip_model = None
clip_preprocess = None
categories = None
+ dtype = None
def __init__(self, content_dir):
self.categories = []
@@ -60,30 +61,45 @@ class InterrogateModels:
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
+ if not shared.cmd_opts.no_half:
+ self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(shared.device)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
+ if not shared.cmd_opts.no_half:
+ self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(shared.device)
- def unload(self):
+ self.dtype = next(self.clip_model.parameters()).dtype
+
+ def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
if self.clip_model is not None:
self.clip_model = self.clip_model.to(devices.cpu)
+ def send_blip_to_ram(self):
+ if not shared.opts.interrogate_keep_models_in_memory:
if self.blip_model is not None:
self.blip_model = self.blip_model.to(devices.cpu)
+ def unload(self):
+ self.send_clip_to_ram()
+ self.send_blip_to_ram()
+
+ devices.torch_gc()
def rank(self, image_features, text_array, top_count=1):
import clip
+ if shared.opts.interrogate_clip_dict_limit != 0:
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
+
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array]).cuda()
- with torch.no_grad():
- text_features = self.clip_model.encode_text(text_tokens).float()
+ text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(shared.device)
@@ -94,13 +110,12 @@ class InterrogateModels:
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
-
def generate_caption(self, pil_image):
gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).to(shared.device)
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
@@ -111,31 +126,41 @@ class InterrogateModels:
res = None
try:
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ devices.torch_gc()
+
self.load()
caption = self.generate_caption(pil_image)
+ self.send_blip_to_ram()
+ devices.torch_gc()
+
res = caption
- images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)
+ cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
- with torch.no_grad():
- image_features = self.clip_model.encode_image(images).float()
+ precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
+ with torch.no_grad(), precision_scope("cuda"):
+ image_features = self.clip_model.encode_image(cilp_image).type(self.dtype)
- image_features /= image_features.norm(dim=-1, keepdim=True)
+ image_features /= image_features.norm(dim=-1, keepdim=True)
- if shared.opts.interrogate_use_builtin_artists:
- artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
+ if shared.opts.interrogate_use_builtin_artists:
+ artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
- res += ", " + artist[0]
+ res += ", " + artist[0]
- for name, topn, items in self.categories:
- matches = self.rank(image_features, items, top_count=topn)
- for match, score in matches:
- res += ", " + match
+ for name, topn, items in self.categories:
+ matches = self.rank(image_features, items, top_count=topn)
+ for match, score in matches:
+ res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ res += "<error>"
self.unload()
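The interrogation path now chooses between torch.autocast and contextlib.nullcontext once, based on the --precision flag, so the same with-block works in both modes. A small sketch of that pattern in isolation; use_autocast here stands in for the precision flag check:

```python
# Sketch of the precision-scope pattern: pick autocast or a no-op
# context manager once, then use the same `with` block either way.
import contextlib
import torch

use_autocast = torch.cuda.is_available()  # stands in for --precision autocast
precision_scope = torch.autocast if use_autocast else contextlib.nullcontext

x = torch.randn(4, 4)
with torch.no_grad(), precision_scope("cuda"):
    y = x @ x  # a model forward pass would go here
print(y.shape)
```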
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 079386c3..7eba1349 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -5,6 +5,16 @@ module_in_gpu = None
cpu = torch.device("cpu")
device = gpu = get_optimal_device()
+
+def send_everything_to_cpu():
+ global module_in_gpu
+
+ if module_in_gpu is not None:
+ module_in_gpu.to(cpu)
+
+ module_in_gpu = None
+
+
def setup_for_low_vram(sd_model, use_medvram):
parents = {}
diff --git a/modules/processing.py b/modules/processing.py
index 542d1136..1e6745cc 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -10,6 +10,7 @@ from PIL import Image, ImageFilter, ImageOps
import random
import modules.sd_hijack
+from modules import devices
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -23,11 +24,6 @@ opt_C = 4
opt_f = 8
-def torch_gc():
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
-
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
@@ -69,6 +65,7 @@ class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
self.images = images_list
self.prompt = p.prompt
+ self.negative_prompt = p.negative_prompt
self.seed = seed
self.info = info
self.width = p.width
@@ -80,6 +77,7 @@ class Processed:
def js(self):
obj = {
"prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
+ "negative_prompt": self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0],
"seed": int(self.seed if type(self.seed) != list else self.seed[0]),
"width": self.width,
"height": self.height,
@@ -174,7 +172,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert p.prompt is not None
- torch_gc()
+ devices.torch_gc()
fix_seed(p)
@@ -195,7 +193,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if type(p.seed) == list:
all_seeds = p.seed
else:
- all_seeds = [int(p.seed + x) for x in range(len(all_prompts))]
+ all_seeds = [int(p.seed + (x if p.subseed_strength == 0 else 0)) for x in range(len(all_prompts))]
if type(p.subseed) == list:
all_subseeds = p.subseed
@@ -275,12 +273,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_sample = x_sample.astype(np.uint8)
if p.restore_faces:
- torch_gc()
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+
+ devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample)
image = Image.fromarray(x_sample)
+
if p.overlay_images is not None and i < len(p.overlay_images):
overlay = p.overlay_images[i]
@@ -296,7 +298,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
image = image.convert('RGB')
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i))
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
output_images.append(image)
@@ -312,9 +314,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
output_images.insert(0, grid)
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
+ images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)
- torch_gc()
+ devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext())
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9eb6cc20..c058ac6e 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -67,8 +67,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
- mem_required = tensor_size * 2.5
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
@@ -86,7 +87,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
- s2 = s1.softmax(dim=-1)
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
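The attention-slicing estimate above now derives the byte size of the q·kᵀ map from q.element_size() instead of assuming float32, and uses a larger safety factor for half precision. A worked example with assumed half-precision shapes (batch×heads = 16, 4096 latent tokens, 40 head dims are illustrative values):

```python
# Worked example of the memory estimate in split_cross_attention_forward.
import torch

q = torch.zeros(16, 4096, 40, dtype=torch.float16)
k = torch.zeros(16, 4096, 40, dtype=torch.float16)

tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()  # bytes of the 16x4096x4096 map
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier

print(tensor_size / 1024**3, "GB for one attention map")   # 0.5 GB
print(mem_required / 1024**3, "GB estimated requirement")  # 1.5 GB
```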
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6b7979e2..95d24299 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -58,12 +58,14 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
- store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec)
+ res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
+ if sampler_wrapper.mask is not None:
+ store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
else:
- store_latent(x_dec)
+ store_latent(res[1])
- return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
+ return res
def extended_tdqm(sequence, *args, desc=None, **kwargs):
diff --git a/modules/shared.py b/modules/shared.py
index 5312768b..9002141a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,8 +13,6 @@ from modules.devices import get_optimal_device
import modules.styles
import modules.interrogate
-config_filename = "config.json"
-
sd_model_file = os.path.join(script_path, 'model.ckpt')
if not os.path.exists(sd_model_file):
sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"
@@ -32,7 +30,7 @@ parser.add_argument("--allow-code", action='store_true', help="allow custom scri
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed if you use --lowvram")
-parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN every time after processing images. Warning: seems to cause memory leaks")
+parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
@@ -42,6 +40,11 @@ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
+parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
+parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
+parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
cmd_opts = parser.parse_args()
@@ -52,6 +55,7 @@ device_seed_type = device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
+config_filename = cmd_opts.ui_settings_file
class State:
interrupted = False
@@ -93,18 +97,20 @@ class Options:
self.component_args = component_args
data = None
+ hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
data_labels = {
- "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to two directories below"),
- "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images'),
- "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images'),
- "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab'),
- "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below"),
- "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids'),
- "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids'),
+ "samples_filename_format": OptionInfo("", "Samples filename format using following tags: [steps],[cfg],[prompt],[prompt_spaces],[width],[height],[sampler],[seed]. Leave blank for default."),
+ "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to two directories below", component_args=hide_dirs),
+ "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
+ "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
+ "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
+ "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
+ "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
+ "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"save_to_dirs": OptionInfo(False, "When writing images, create a directory with name derived from the prompt"),
"grid_save_to_dirs": OptionInfo(False, "When writing grids, create a directory with name derived from the prompt"),
"save_to_dirs_prompt_len": OptionInfo(10, "When using above, how many words from prompt to put into directory name", gr.Slider, {"minimum": 1, "maximum": 32, "step": 1}),
- "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button"),
+ "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": OptionInfo('png', 'File format for individual samples'),
"grid_save": OptionInfo(True, "Save image grids"),
@@ -128,11 +134,14 @@ class Options:
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "interrogate_keep_models_in_memory": OptionInfo(True, "Interrogate: keep models in VRAM"),
+ "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+ "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
}
def __init__(self):
diff --git a/modules/ui.py b/modules/ui.py
index b9af2c86..535afaeb 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -93,7 +93,7 @@ def save_files(js_data, images):
at_start = file.tell() == 0
writer = csv.writer(file)
if at_start:
- writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename"])
+ writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
filename_base = str(int(time.time() * 1000))
for i, filedata in enumerate(images):
@@ -108,7 +108,7 @@ def save_files(js_data, images):
filenames.append(filename)
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0]])
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
@@ -270,7 +270,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
- cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
with gr.Group():
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
@@ -384,8 +384,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False, image_mode="RGBA")
- init_img_with_mask_comment = gr.HTML(elem_id="mask_bug_info", value="<small>if the editor shows ERROR, switch to another tab and back, then to another img2img mode above and back</small>", visible=False)
init_mask = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False)
+ init_img_with_mask_comment = gr.HTML(elem_id="mask_bug_info", value="<small>if the editor shows ERROR, switch to another tab and back, then to another img2img mode above and back</small>", visible=False)
with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
@@ -413,7 +413,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group():
- cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, visible=False)
@@ -661,19 +661,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
info = opts.data_labels[key]
t = type(info.default)
+ args = info.component_args() if callable(info.component_args) else info.component_args
+
if info.component is not None:
- args = info.component_args() if callable(info.component_args) else info.component_args
- item = info.component(label=info.label, value=fun, **(args or {}))
+ comp = info.component
elif t == str:
- item = gr.Textbox(label=info.label, value=fun, lines=1)
+ comp = gr.Textbox
elif t == int:
- item = gr.Number(label=info.label, value=fun)
+ comp = gr.Number
elif t == bool:
- item = gr.Checkbox(label=info.label, value=fun)
+ comp = gr.Checkbox
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return item
+ return comp(label=info.label, value=fun, **(args or {}))
components = []
keys = list(opts.data_labels.keys())
@@ -684,6 +685,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
up = []
for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ continue
+
opts.data[key] = value
up.append(comp.update(value=value))
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
new file mode 100644
index 00000000..16a2fdf6
--- /dev/null
+++ b/scripts/img2imgalt.py
@@ -0,0 +1,104 @@
+import numpy as np
+from tqdm import trange
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import processing, shared, sd_samplers
+from modules.processing import Processed
+from modules.sd_samplers import samplers
+from modules.shared import opts, cmd_opts, state
+
+import torch
+import k_diffusion as K
+
+from PIL import Image
+from torch import autocast
+from einops import rearrange, repeat
+
+
+def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
+ x = p.init_latent
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(shared.sd_model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ shared.state.sampling_steps = steps
+
+ for i in trange(1, len(sigmas)):
+ shared.state.sampling_step += 1
+
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
+
+ d = (x - denoised) / sigmas[i]
+ dt = sigmas[i] - sigmas[i - 1]
+
+ x = x + d * dt
+
+ sd_samplers.store_latent(x)
+
+ # This shouldn't be necessary, but solved some VRAM issues
+ del x_in, sigma_in, cond_in, c_out, c_in, t,
+ del eps, denoised_uncond, denoised_cond, denoised, d, dt
+
+ shared.state.nextjob()
+
+ return x / x.std()
+
+cache = [None, None, None, None, None]
+
+class Script(scripts.Script):
+ def title(self):
+ return "img2img alternative test"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ original_prompt = gr.Textbox(label="Original prompt", lines=1)
+ cfg = gr.Slider(label="Decode CFG scale", minimum=0.1, maximum=3.0, step=0.1, value=1.0)
+ st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
+
+ return [original_prompt, cfg, st]
+
+ def run(self, p, original_prompt, cfg, st):
+ p.batch_size = 1
+ p.batch_count = 1
+
+ def sample_extra(x, conditioning, unconditional_conditioning):
+ lat = tuple([int(x*10) for x in p.init_latent.cpu().numpy().flatten().tolist()])
+
+ if cache[0] is not None and cache[1] == cfg and cache[2] == st and len(cache[3]) == len(lat) and sum(np.array(cache[3])-np.array(lat)) < 100 and cache[4] == original_prompt:
+ noise = cache[0]
+ else:
+ shared.state.job_count += 1
+ cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
+ noise = find_noise_for_image(p, cond, unconditional_conditioning, cfg, st)
+ cache[0] = noise
+ cache[1] = cfg
+ cache[2] = st
+ cache[3] = lat
+ cache[4] = original_prompt
+
+ sampler = samplers[p.sampler_index].constructor(p.sd_model)
+
+ samples_ddim = sampler.sample(p, noise, conditioning, unconditional_conditioning)
+ return samples_ddim
+
+ p.sample = sample_extra
+
+ processed = processing.process_images(p)
+
+ return processed
+
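find_noise_for_image above is essentially the Euler sampler run with the sigma schedule reversed: starting from the clean init latent, each step moves along d = (x - denoised) / sigma with a growing sigma, recovering a noise tensor that later decodes back to roughly the original image. A toy 1-D sketch of just that update rule, with a placeholder denoiser standing in for the Stable Diffusion model:

```python
# Toy sketch of the reversed Euler update used in find_noise_for_image.
import torch

def fake_denoiser(x, sigma):
    # stand-in for the CFG-combined model prediction
    return x / (1.0 + sigma)

x = torch.tensor([0.5, -0.25, 1.0])     # plays the role of p.init_latent
sigmas = torch.linspace(0.0, 10.0, 50)  # ascending, like get_sigmas(steps).flip(0)

for i in range(1, len(sigmas)):
    denoised = fake_denoiser(x, sigmas[i])
    d = (x - denoised) / sigmas[i]      # same update as the script
    dt = sigmas[i] - sigmas[i - 1]
    x = x + d * dt

print(x / x.std())                      # normalized "recovered noise"
```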
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index c029c67f..b0469110 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -4,7 +4,7 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing
+from modules import images, processing, devices
from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
@@ -77,7 +77,7 @@ class Script(scripts.Script):
mask.height - down - (mask_blur//2 if down > 0 else 0)
), fill="black")
- processing.torch_gc()
+ devices.torch_gc()
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
@@ -139,7 +139,7 @@ class Script(scripts.Script):
combined_image = images.combine_grid(grid)
if opts.samples_save:
- images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info)
+ images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p)
processed = Processed(p, [combined_image], initial_seed, initial_info)
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index 82096b0d..aaece054 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -82,6 +82,6 @@ class Script(scripts.Script):
processed.images.insert(0, grid)
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed)
+ images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, p=p)
return processed
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
new file mode 100644
index 00000000..da2ddd54
--- /dev/null
+++ b/scripts/prompts_from_file.py
@@ -0,0 +1,42 @@
+import math
+import os
+import sys
+import traceback
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules.processing import Processed, process_images
+from PIL import Image
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Prompts from file"
+
+ def ui(self, is_img2img):
+ file = gr.File(label="File with inputs", type='bytes')
+
+ return [file]
+
+ def run(self, p, data: bytes):
+ lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
+ lines = [x for x in lines if len(x) > 0]
+
+ batch_count = math.ceil(len(lines) / p.batch_size)
+ print(f"Will process {len(lines)} images in {batch_count} batches.")
+
+ p.batch_count = 1
+ p.do_not_save_grid = True
+
+ state.job_count = batch_count
+
+ images = []
+ for batch_no in range(batch_count):
+ state.job = f"{batch_no} out of {batch_count}"
+ p.prompt = lines[batch_no*p.batch_size:(batch_no+1)*p.batch_size]
+ proc = process_images(p)
+ images += proc.images
+
+ return Processed(p, images, p.seed, "")
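The new script reads one prompt per line and feeds them to process_images in batch_size-sized groups, one job per batch. A small sketch of that batching arithmetic on its own (the prompt list is illustrative):

```python
# Sketch of the batching done by the "Prompts from file" script.
import math

lines = ["a cat", "a dog", "a fox", "a bird", "a fish"]
batch_size = 2
batch_count = math.ceil(len(lines) / batch_size)

for batch_no in range(batch_count):
    prompts = lines[batch_no * batch_size:(batch_no + 1) * batch_size]
    print(batch_no, prompts)
# 0 ['a cat', 'a dog']
# 1 ['a fox', 'a bird']
# 2 ['a fish']
```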
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index f511cb4a..dd6db81c 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -192,6 +192,6 @@ class Script(scripts.Script):
)
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed)
+ images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, p=p)
return processed
diff --git a/webui.bat b/webui.bat
index 54734d07..6e1e22da 100644
--- a/webui.bat
+++ b/webui.bat
@@ -37,13 +37,15 @@ goto :show_stdout_stderr
set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
%PYTHON% --version
echo venv %PYTHON%
-goto :install_torch
+goto :print_commit
:skip_venv
%PYTHON% --version
-:install_torch
+:print_commit
+%GIT% rev-parse HEAD
+:install_torch
%PYTHON% -c "import torch" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_gpu
echo Installing torch...
diff --git a/webui.py b/webui.py
index ca809f79..953e6620 100644
--- a/webui.py
+++ b/webui.py
@@ -49,7 +49,8 @@ def load_model_from_config(config, ckpt, verbose=False):
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
-
+ if cmd_opts.opt_channelslast:
+ model = model.to(memory_format=torch.channels_last)
model.eval()
return model
@@ -115,7 +116,13 @@ def webui():
run_pnginfo=modules.extras.run_pnginfo
)
- demo.launch(share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None, server_port=cmd_opts.port)
+ demo.launch(
+ share=cmd_opts.share,
+ server_name="0.0.0.0" if cmd_opts.listen else None,
+ server_port=cmd_opts.port,
+ debug=cmd_opts.gradio_debug,
+ auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
+ )
if __name__ == "__main__":
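The --gradio-auth value added in shared.py is parsed in demo.launch above: a quoted "user:password" string, optionally comma-delimited for multiple users, becomes the list of tuples gradio expects. A tiny sketch of that parsing in isolation (parse_gradio_auth is an illustrative helper, not webui code):

```python
# Sketch of the --gradio-auth parsing passed to demo.launch.
def parse_gradio_auth(value):
    if not value:
        return None
    return [tuple(cred.split(':')) for cred in value.strip('"').split(',')]

print(parse_gradio_auth('"u1:p1,u2:p2"'))  # [('u1', 'p1'), ('u2', 'p2')]
print(parse_gradio_auth(None))             # None
```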