aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-10-29 15:04:21 +0700
committerMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-10-29 15:04:21 +0700
commit05e2e40537a948d7563d873ffbc394c41a0cd0b1 (patch)
tree5f46c59cab2d36989eed21411dce5271ec540776 /modules
parent16451ca573220e49f2eaaab97580b6b91287c8c4 (diff)
parent35c45df28b303a05d56a13cb56d4046f08cf8c25 (diff)
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py104
-rw-r--r--modules/api/models.py55
-rw-r--r--modules/extras.py178
-rw-r--r--modules/generation_parameters_copypaste.py116
-rw-r--r--modules/hypernetworks/hypernetwork.py13
-rw-r--r--modules/hypernetworks/ui.py3
-rw-r--r--modules/images.py9
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/processing.py127
-rw-r--r--modules/scripts.py2
-rw-r--r--modules/sd_models.py18
-rw-r--r--modules/shared.py8
-rw-r--r--modules/textual_inversion/dataset.py4
-rw-r--r--modules/textual_inversion/learn_schedule.py2
-rw-r--r--modules/textual_inversion/textual_inversion.py24
-rw-r--r--modules/ui.py378
16 files changed, 581 insertions, 462 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 6e9d6097..49c213ea 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,46 +1,37 @@
-from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+import uvicorn
+from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from fastapi import APIRouter, HTTPException
+import modules.shared as shared
+from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
-from modules.extras import run_pnginfo
-import modules.shared as shared
-import uvicorn
-from fastapi import Body, APIRouter, HTTPException
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel, Field, Json
-from typing import List
-import json
-import io
-import base64
-from PIL import Image
-
-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+from modules.extras import run_extras
-class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
+def upscaler_to_index(name: str):
+ try:
+ return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
+ except:
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
-class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
+sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+def setUpscalers(req: dict):
+ reqDict = vars(req)
+ reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
+ reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
+ reqDict.pop('upscaler_1')
+ reqDict.pop('upscaler_2')
+ return reqDict
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
- self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
- self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
-
- def __base64_to_image(self, base64_string):
- # if has a comma, deal with prefix
- if "," in base64_string:
- base64_string = base64_string.split(",")[1]
- imgdata = base64.b64decode(base64_string)
- # convert base64 to PIL image
- return Image.open(io.BytesIO(imgdata))
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
+ self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
+ self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -60,15 +51,9 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
-
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
-
+ b64images = list(map(encode_pil_to_base64, processed.images))
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -83,7 +68,7 @@ class Api:
mask = img2imgreq.mask
if mask:
- mask = self.__base64_to_image(mask)
+ mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -98,7 +83,7 @@ class Api:
imgs = []
for img in init_images:
- img = self.__base64_to_image(img)
+ img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
@@ -106,21 +91,40 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
+ b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
+
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
- return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
+ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ reqDict = setUpscalers(req)
- def extrasapi(self):
- raise NotImplementedError
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
+
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+
+ def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ reqDict = setUpscalers(req)
+
+ def prepareFiles(file):
+ file = decode_base64_to_file(file.data, file_path=file.name)
+ file.orig_name = file.name
+ return file
+
+ reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
+ reqDict.pop('imageList')
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+
def pnginfoapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index 079e33d9..dd122321 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,10 +1,10 @@
-from array import array
-from inflection import underscore
-from typing import Any, Dict, Optional
+import inspect
from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional
+from typing_extensions import Literal
+from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-import inspect
-
+from modules.shared import sd_upscalers
API_NOT_ALLOWED = [
"self",
@@ -105,4 +105,47 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
-).generate_model() \ No newline at end of file
+).generate_model()
+
+class TextToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
+ info: str
+
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
+ info: str
+
+class ExtrasBaseRequest(BaseModel):
+ resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+ show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
+ gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+ codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+ codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+ upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
+ upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+
+class ExtraBaseResponse(BaseModel):
+ html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
+
+class ExtrasSingleImageRequest(ExtrasBaseRequest):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class ExtrasSingleImageResponse(ExtraBaseResponse):
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+
+class FileData(BaseModel):
+ data: str = Field(title="File data", description="Base64 representation of the file")
+ name: str = Field(title="File name")
+
+class ExtrasBatchImagesRequest(ExtrasBaseRequest):
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+
+class ExtrasBatchImagesResponse(ExtraBaseResponse):
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file
diff --git a/modules/extras.py b/modules/extras.py
index 22c5a1c1..4d51088b 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import math
import os
@@ -7,6 +8,10 @@ from PIL import Image
import torch
import tqdm
+from typing import Callable, List, OrderedDict, Tuple
+from functools import partial
+from dataclasses import dataclass
+
from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
@@ -17,10 +22,38 @@ import piexif.helper
import gradio as gr
-cached_images = {}
+class LruCache(OrderedDict):
+ @dataclass(frozen=True)
+ class Key:
+ image_hash: int
+ info_hash: int
+ args_hash: int
+
+ @dataclass
+ class Value:
+ image: Image.Image
+ info: str
+
+ def __init__(self, max_size: int = 5, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._max_size = max_size
+
+ def get(self, key: LruCache.Key) -> LruCache.Value:
+ ret = super().get(key)
+ if ret is not None:
+ self.move_to_end(key) # Move to end of eviction list
+ return ret
+
+ def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
+ self[key] = value
+ while len(self) > self._max_size:
+ self.popitem(last=False)
+
+cached_images: LruCache = LruCache(max_size=5)
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
devices.torch_gc()
imageArr = []
@@ -39,7 +72,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
- image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
+ image_list = shared.listfiles(input_dir)
for img in image_list:
try:
image = Image.open(img)
@@ -56,72 +89,102 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- for image, image_name in zip(imageArr, imageNameArr):
- if image is None:
- return outputs, "Please select an input image.", ''
- existing_pnginfo = image.info or {}
+ # Extra operation definitions
- image = image.convert("RGB")
- info = ""
+ def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
- if gfpgan_visibility > 0:
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(image, res, gfpgan_visibility)
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
+ info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+ return (res, info)
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- image = res
+ def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
- if codeformer_visibility > 0:
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
+ if codeformer_visibility < 1.0:
+ res = Image.blend(image, res, codeformer_visibility)
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
+ info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
+ return (res, info)
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
- image = res
+ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ upscaler = shared.sd_upscalers[scaler_index]
+ res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ if mode == 1 and crop:
+ cropped = Image.new("RGB", (resize_w, resize_h))
+ cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
+ res = cropped
+ return res
+ def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
+ nonlocal upscaling_resize
if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+ return (image, info)
+
+ @dataclass
+ class UpscaleParams:
+ upscaler_idx: int
+ blend_alpha: float
+
+ def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ blended_result: Image.Image = None
+ for upscaler in params:
+ upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
+ upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+ cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
+ info_hash=hash(info),
+ args_hash=hash(upscale_args))
+ cached_entry = cached_images.get(cache_key)
+ if cached_entry is None:
+ res = upscale(image, *upscale_args)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
+ cached_images.put(cache_key, LruCache.Value(image=res, info=info))
+ else:
+ res, info = cached_entry.image, cached_entry.info
+
+ if blended_result is None:
+ blended_result = res
+ else:
+ blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
+ return (blended_result, info)
+
+ # Build a list of operations to run
+ facefix_ops: List[Callable] = []
+ facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
+ facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
+
+ upscale_ops: List[Callable] = []
+ upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
+
+ if upscaling_resize != 0:
+ step_params: List[UpscaleParams] = []
+ step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
+
+ upscale_ops.append(partial(run_upscalers_blend, step_params))
+
+ extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
+
+ for image, image_name in zip(imageArr, imageNameArr):
+ if image is None:
+ return outputs, "Please select an input image.", ''
+ existing_pnginfo = image.info or {}
+
+ image = image.convert("RGB")
+ info = ""
+ # Run each operation on each image
+ for op in extras_ops:
+ image, info = op(image, info)
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
- resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
-
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
- c = cropped
- cached_images[key] = c
-
- return c
-
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
-
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
-
- image = res
-
- while len(cached_images) > 2:
- del cached_images[next(iter(cached_images.keys()))]
-
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
@@ -141,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, plaintext_to_html(info), ''
+def clear_cache():
+ cached_images.clear()
+
def run_pnginfo(image):
if image is None:
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index f73647da..bbaad42e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,14 +1,20 @@
+import base64
+import io
import os
import re
import gradio as gr
from modules.shared import script_path
from modules import shared
+import tempfile
+from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
+paste_fields = {}
+bind_list = []
def quote(text):
@@ -20,6 +26,110 @@ def quote(text):
text = text.replace('"', '\\"')
return f'"{text}"'
+
+def image_from_url_text(filedata):
+ if type(filedata) == dict and filedata["is_file"]:
+ filename = filedata["name"]
+ tempdir = os.path.normpath(tempfile.gettempdir())
+ normfn = os.path.normpath(filename)
+ assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
+
+ return Image.open(filename)
+
+ if type(filedata) == list:
+ if len(filedata) == 0:
+ return None
+
+ filedata = filedata[0]
+
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
+ image = Image.open(io.BytesIO(filedata))
+ return image
+
+
+def add_paste_fields(tabname, init_img, fields):
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields}
+
+ # backwards compatibility for existing extensions
+ import modules.ui
+ if tabname == 'txt2img':
+ modules.ui.txt2img_paste_fields = fields
+ elif tabname == 'img2img':
+ modules.ui.img2img_paste_fields = fields
+
+
+def integrate_settings_paste_fields(component_dict):
+ from modules import ui
+
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'sd_model_checkpoint': 'Model hash',
+ }
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ for tabname, info in paste_fields.items():
+ if info["fields"] is not None:
+ info["fields"] += settings_paste_fields
+
+
+def create_buttons(tabs_list):
+ buttons = {}
+ for tab in tabs_list:
+ buttons[tab] = gr.Button(f"Send to {tab}")
+ return buttons
+
+
+#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
+def bind_buttons(buttons, send_image, send_generate_info):
+ bind_list.append([buttons, send_image, send_generate_info])
+
+
+def run_bind():
+ for buttons, send_image, send_generate_info in bind_list:
+ for tab in buttons:
+ button = buttons[tab]
+ if send_image and paste_fields[tab]["init_img"]:
+ if type(send_image) == gr.Gallery:
+ button.click(
+ fn=lambda x: image_from_url_text(x),
+ _js="extract_image_from_gallery",
+ inputs=[send_image],
+ outputs=[paste_fields[tab]["init_img"]],
+ )
+ else:
+ button.click(
+ fn=lambda x: x,
+ inputs=[send_image],
+ outputs=[paste_fields[tab]["init_img"]],
+ )
+
+ if send_generate_info and paste_fields[tab]["fields"] is not None:
+ if send_generate_info in paste_fields:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
+
+ button.click(
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
+ )
+ else:
+ connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
+
+ button.click(
+ fn=None,
+ _js=f"switch_to_{tab}",
+ inputs=None,
+ outputs=None,
+ )
+
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -68,7 +178,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-def connect_paste(button, paste_fields, input_comp, js=None):
+def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
@@ -106,7 +216,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
button.click(
fn=paste_func,
- _js=js,
+ _js=jsfunc,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
+
+
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 86532063..f45ce199 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,6 +25,7 @@ from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
+ "linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
@@ -443,7 +444,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ steps_done = hypernetwork.step + 1
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1:
@@ -453,9 +456,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
- if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
@@ -464,8 +467,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
"learn_rate": scheduler.learn_rate
})
- if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 2c6c0470..aad09ffc 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
+not_available = ["hardswish", "multiheadattention"]
+keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
diff --git a/modules/images.py b/modules/images.py
index 7870b5b7..a0728553 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -300,8 +300,8 @@ class FilenameGenerator:
'seed': lambda self: self.seed if self.seed is not None else '',
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
- 'width': lambda self: self.p and self.p.width,
- 'height': lambda self: self.p and self.p.height,
+ 'width': lambda self: self.image.width,
+ 'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
@@ -315,10 +315,11 @@ class FilenameGenerator:
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt):
+ def __init__(self, p, seed, prompt, image):
self.p = p
self.seed = seed
self.prompt = prompt
+ self.image = image
def prompt_no_style(self):
if self.p is None or self.prompt is None:
@@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
"""
- namegen = FilenameGenerator(p, seed, prompt)
+ namegen = FilenameGenerator(p, seed, prompt, image)
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
diff --git a/modules/img2img.py b/modules/img2img.py
index 9c0cf23e..efda26e1 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -19,7 +19,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, args):
processing.fix_seed(p)
- images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ images = shared.listfiles(input_dir)
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
diff --git a/modules/processing.py b/modules/processing.py
index 4efba946..548eec29 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -129,6 +129,73 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
+ def txt2img_image_conditioning(self, x, width=None, height=None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return torch.zeros(
+ x.shape[0], 5, 1, 1,
+ dtype=x.dtype,
+ device=x.device
+ )
+
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ return torch.zeros(
+ latent_image.shape[0], 5, 1, 1,
+ dtype=latent_image.dtype,
+ device=latent_image.device
+ )
+
+ # Handle the different mask inputs
+ if image_mask is not None:
+ if torch.is_tensor(image_mask):
+ conditioning_mask = image_mask
+ else:
+ conditioning_mask = np.array(image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(source_image.device)
+ conditioning_image = torch.lerp(
+ source_image,
+ source_image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
+
+ # Encode the new masked image using first stage of network.
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ return image_conditioning
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def create_dummy_mask(self, x, width=None, height=None):
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
+ image_conditioning = self.img2img_image_conditioning(
+ decoded_samples,
+ samples,
+ decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
+ )
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
@@ -770,33 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
- else:
- self.image_conditioning = torch.zeros(
- self.init_latent.shape[0], 5, 1, 1,
- dtype=self.init_latent.dtype,
- device=self.init_latent.device
- )
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
diff --git a/modules/scripts.py b/modules/scripts.py
index 9323af3e..a7f36012 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -236,7 +236,7 @@ class ScriptRunner:
with gr.Group():
create_script_ui(script, inputs, inputs_alwayson)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e697bb72..f86dc3ed 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -3,6 +3,7 @@ import os.path
import sys
from collections import namedtuple
import torch
+import re
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
@@ -35,8 +36,10 @@ def setup_model():
list_models()
-def checkpoint_tiles():
- return sorted([x.title for x in checkpoints_list.values()])
+def checkpoint_tiles():
+ convert = lambda name: int(name) if name.isdigit() else name.lower()
+ alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+ return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def list_models():
@@ -170,7 +173,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
- missing, extra = model.load_state_dict(sd, strict=False)
+ del pl_sd
+ model.load_state_dict(sd, strict=False)
+ del sd
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
@@ -194,9 +199,10 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae)
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
- checkpoints_loaded.popitem(last=False) # LRU
+ if shared.opts.sd_checkpoint_cache > 0:
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
diff --git a/modules/shared.py b/modules/shared.py
index 1a9b8289..fb84afd8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -82,6 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
cmd_opts = parser.parse_args()
restricted_opts = {
@@ -267,6 +268,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -303,6 +305,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
@@ -449,3 +452,8 @@ total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
+
+
+def listfiles(dirname):
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
+ return [file for file in filenames if os.path.isfile(file)]
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 5b1c5002..8bb00d27 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -86,12 +86,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(len(self.dataset))
+ self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
+ self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index ffec3e1b..2627d585 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -52,7 +52,7 @@ class LearnRateScheduler:
self.finished = False
def step(self, step_number):
- if step_number <= self.end_step:
+ if step_number < self.end_step:
return False
try:
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6b00c6a1..f272e536 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
- if step % shared.opts.training_write_csv_every != 0:
+ if (step + 1) % shared.opts.training_write_csv_every != 0:
return
-
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = step // epoch_len
- epoch_step = step - epoch * epoch_len
+ epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
- "epoch": epoch + 1,
+ "epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
@@ -297,15 +296,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.step()
+ steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+ epoch_step = embedding.step % len(ds)
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{embedding.step}'
+ embedding.name = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -315,8 +315,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
"learn_rate": scheduler.learn_rate
})
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- forced_filename = f'{embedding_name}-{embedding.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -349,7 +349,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
@@ -365,7 +365,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
diff --git a/modules/ui.py b/modules/ui.py
index 47d16429..98f9565f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,6 +1,4 @@
-import base64
import html
-import io
import json
import math
import mimetypes
@@ -18,13 +16,8 @@ import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np
-import piexif
-import torch
from PIL import Image, PngImagePlugin
-import gradio as gr
-import gradio.utils
-import gradio.routes
from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
@@ -35,7 +28,7 @@ if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.codeformer_model
-import modules.generation_parameters_copypaste
+import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model
import modules.hypernetworks.ui
import modules.ldsr_model
@@ -49,13 +42,11 @@ from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
+from modules.generation_parameters_copypaste import image_from_url_text
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
-txt2img_paste_fields = []
-img2img_paste_fields = []
-
if not cmd_opts.share and not cmd_opts.listen:
# fix gradio phoning home
@@ -98,37 +89,11 @@ def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
return text
-
-def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
- filename = filedata["name"]
- tempdir = os.path.normpath(tempfile.gettempdir())
- normfn = os.path.normpath(filename)
- assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
-
- return Image.open(filename)
-
- if type(filedata) == list:
- if len(filedata) == 0:
- return None
-
- filedata = filedata[0]
-
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- filedata = base64.decodebytes(filedata.encode('utf-8'))
- image = Image.open(io.BytesIO(filedata))
- return image
-
-
def send_gradio_gallery_to_image(x):
if len(x) == 0:
return None
-
return image_from_url_text(x[0])
-
def save_files(js_data, images, do_make_zip, index):
import csv
filenames = []
@@ -192,7 +157,6 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-
def save_pil_to_file(pil_image, dir=None):
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -626,6 +590,83 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
return refresh_button
+def create_output_panel(tabname, outdir):
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel'):
+ with gr.Group():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row():
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
+ open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
+
+ open_folder_button.click(
+ fn=lambda: open_folder(opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
+
+ save.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ do_make_zip,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_info,
+ html_info,
+ html_info,
+ ]
+ )
+ else:
+ html_info_x = gr.HTML()
+ html_info = gr.HTML()
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
+
+
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
@@ -675,30 +716,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
- with gr.Column(variant='panel'):
-
- with gr.Group():
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
-
- with gr.Column():
- with gr.Row():
- save = gr.Button('Save')
- send_to_img2img = gr.Button('Send to img2img')
- send_to_inpaint = gr.Button('Send to inpaint')
- send_to_extras = gr.Button('Send to extras')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -756,23 +775,6 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[hr_options],
)
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
- inputs=[
- generation_info,
- txt2img_gallery,
- do_make_zip,
- html_info,
- ],
- outputs=[
- download_files,
- html_info,
- html_info,
- html_info,
- ]
- )
-
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
@@ -784,7 +786,6 @@ def create_ui(wrap_gradio_gpu_call):
]
)
- global txt2img_paste_fields
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
@@ -807,6 +808,7 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_height, "First pass size-2"),
*modules.scripts.scripts_txt2img.infotext_fields
]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
txt2img_preview_params = [
txt2img_prompt,
@@ -893,30 +895,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
- with gr.Column(variant='panel'):
-
- with gr.Group():
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
-
- with gr.Column():
- with gr.Row():
- save = gr.Button('Save')
- img2img_send_to_img2img = gr.Button('Send to img2img')
- img2img_send_to_inpaint = gr.Button('Send to inpaint')
- img2img_send_to_extras = gr.Button('Send to extras')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
+ parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -1003,25 +983,9 @@ def create_ui(wrap_gradio_gpu_call):
fn=interrogate_deepbooru,
inputs=[init_img],
outputs=[img2img_prompt],
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
- inputs=[
- generation_info,
- img2img_gallery,
- do_make_zip,
- html_info,
- ],
- outputs=[
- download_files,
- html_info,
- html_info,
- html_info,
- ]
)
+
roll.click(
fn=roll_artist,
_js="update_img2img_tokens",
@@ -1055,7 +1019,8 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[prompt, negative_prompt, style1, style2],
)
- global img2img_paste_fields
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1074,7 +1039,8 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
*modules.scripts.scripts_img2img.infotext_fields
]
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
@@ -1087,12 +1053,8 @@ def create_ui(wrap_gradio_gpu_call):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.TabItem('Batch from Directory'):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
- placeholder="A directory on the same machine where the server is running."
- )
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
- placeholder="Leave blank to save images to the default path."
- )
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
with gr.Tabs(elem_id="extras_resize_mode"):
@@ -1104,9 +1066,9 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
-
+
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@@ -1119,17 +1081,12 @@ def create_ui(wrap_gradio_gpu_call):
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+ with gr.Group():
+ upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
- with gr.Column(variant='panel'):
- result_images = gr.Gallery(label="Result", show_label=False)
- html_info_x = gr.HTML()
- html_info = gr.HTML()
- extras_send_to_img2img = gr.Button('Send to img2img')
- extras_send_to_inpaint = gr.Button('Send to inpaint')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
- open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+ result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
@@ -1152,6 +1109,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_upscaler_1,
extras_upscaler_2,
extras_upscaler_2_visibility,
+ upscale_before_face_fix,
],
outputs=[
result_images,
@@ -1159,19 +1117,11 @@ def create_ui(wrap_gradio_gpu_call):
html_info,
]
)
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
- extras_send_to_img2img.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
- inputs=[result_images],
- outputs=[init_img],
- )
-
- extras_send_to_inpaint.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_inpaint",
- inputs=[result_images],
- outputs=[init_img_with_mask],
+ extras_image.change(
+ fn=modules.extras.clear_cache,
+ inputs=[], outputs=[]
)
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
@@ -1183,17 +1133,16 @@ def create_ui(wrap_gradio_gpu_call):
html = gr.HTML()
generation_info = gr.Textbox(visible=False)
html2 = gr.HTML()
-
with gr.Row():
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+ parameters_copypaste.bind_buttons(buttons, image, generation_info)
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
-
+
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1238,7 +1187,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
@@ -1497,28 +1446,6 @@ def create_ui(wrap_gradio_gpu_call):
script_callbacks.ui_settings_callback()
opts.reorder()
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- else:
- sp.Popen(["xdg-open", path])
-
def run_settings(*args):
changed = 0
@@ -1672,10 +1599,6 @@ Requested path was: {f}
(train_interface, "Train", "ti"),
]
- interfaces += script_callbacks.ui_tabs_callback()
-
- interfaces += [(settings_interface, "Settings", "settings")]
-
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
@@ -1692,6 +1615,9 @@ Requested path was: {f}
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "Settings", "settings")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
@@ -1700,6 +1626,9 @@ Requested path was: {f}
settings_interface.gradio_ref = demo
+ parameters_copypaste.integrate_settings_paste_fields(component_dict)
+ parameters_copypaste.run_bind()
+
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
@@ -1753,85 +1682,6 @@ Requested path was: {f}
component_dict['sd_model_checkpoint'],
]
)
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
- txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
- img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
- send_to_img2img.click(
- fn=lambda img, *args: (image_from_url_text(img),*args),
- _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
- inputs=[txt2img_gallery] + txt2img_fields,
- outputs=[init_img] + img2img_fields,
- )
-
- send_to_inpaint.click(
- fn=lambda x, *args: (image_from_url_text(x), *args),
- _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
- inputs=[txt2img_gallery] + txt2img_fields,
- outputs=[init_img_with_mask] + img2img_fields,
- )
-
- img2img_send_to_img2img.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
- inputs=[img2img_gallery],
- outputs=[init_img],
- )
-
- img2img_send_to_inpaint.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_inpaint",
- inputs=[img2img_gallery],
- outputs=[init_img_with_mask],
- )
-
- send_to_extras.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_extras",
- inputs=[txt2img_gallery],
- outputs=[extras_image],
- )
-
- open_txt2img_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
- inputs=[],
- outputs=[],
- )
-
- open_img2img_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
- inputs=[],
- outputs=[],
- )
-
- open_extras_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
- inputs=[],
- outputs=[],
- )
-
- img2img_send_to_extras.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_extras",
- inputs=[img2img_gallery],
- outputs=[extras_image],
- )
-
- settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'CLIP_stop_at_last_layers': 'Clip skip',
- 'sd_model_checkpoint': 'Model hash',
- }
-
- settings_paste_fields = [
- (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
- for k, v in settings_map.items()
- ]
-
- modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
- modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
-
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
@@ -1851,7 +1701,7 @@ Requested path was: {f}
def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field
- if getattr(obj,'custom_script_source',None) is not None:
+ if getattr(obj, 'custom_script_source', None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key
if getattr(obj, 'do_not_save_to_config', False):
@@ -1911,7 +1761,7 @@ def load_javascript(raw_response):
javascript = f'<script>{jsfile.read()}</script>'
scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
+
for basedir, filename, path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"