aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py482
-rw-r--r--modules/api/models.py261
-rw-r--r--modules/api/processing.py99
-rw-r--r--modules/bsrgan_model.py76
-rw-r--r--modules/bsrgan_model_arch.py102
-rw-r--r--modules/call_queue.py98
-rw-r--r--modules/codeformer/vqgan_arch.py4
-rw-r--r--modules/codeformer_model.py3
-rw-r--r--modules/deepbooru.py259
-rw-r--r--modules/deepbooru_model.py676
-rw-r--r--modules/devices.py112
-rw-r--r--modules/errors.py25
-rw-r--r--modules/esrgan_model.py199
-rw-r--r--modules/esrgan_model_arch.py487
-rw-r--r--modules/extensions.py99
-rw-r--r--modules/extras.py320
-rw-r--r--modules/generation_parameters_copypaste.py243
-rw-r--r--modules/gfpgan_model.py4
-rw-r--r--modules/hypernetworks/hypernetwork.py535
-rw-r--r--modules/hypernetworks/ui.py29
-rw-r--r--modules/images.py380
-rw-r--r--modules/images_history.py183
-rw-r--r--modules/img2img.py50
-rw-r--r--modules/import_hook.py5
-rw-r--r--modules/interrogate.py36
-rw-r--r--modules/ldsr_model.py54
-rw-r--r--modules/ldsr_model_arch.py222
-rw-r--r--modules/localization.py6
-rw-r--r--modules/lowvram.py50
-rw-r--r--modules/masking.py2
-rw-r--r--modules/memmon.py3
-rw-r--r--modules/modelloader.py44
-rw-r--r--modules/ngrok.py15
-rw-r--r--modules/paths.py2
-rw-r--r--modules/processing.py473
-rw-r--r--modules/safe.py113
-rw-r--r--modules/safety.py42
-rw-r--r--modules/script_callbacks.py281
-rw-r--r--modules/script_loading.py34
-rw-r--r--modules/scripts.py397
-rw-r--r--modules/scunet_model.py88
-rw-r--r--modules/scunet_model_arch.py265
-rw-r--r--modules/sd_hijack.py364
-rw-r--r--modules/sd_hijack_checkpoint.py10
-rw-r--r--modules/sd_hijack_clip.py303
-rw-r--r--modules/sd_hijack_inpainting.py111
-rw-r--r--modules/sd_hijack_open_clip.py37
-rw-r--r--modules/sd_hijack_optimizations.py10
-rw-r--r--modules/sd_hijack_unet.py30
-rw-r--r--modules/sd_hijack_xlmr.py34
-rw-r--r--modules/sd_models.py238
-rw-r--r--modules/sd_samplers.py261
-rw-r--r--modules/sd_vae.py231
-rw-r--r--modules/sd_vae_approx.py58
-rw-r--r--modules/shared.py283
-rw-r--r--modules/styles.py11
-rw-r--r--modules/swinir_model.py161
-rw-r--r--modules/swinir_model_arch.py867
-rw-r--r--modules/swinir_model_arch_v2.py1017
-rw-r--r--modules/textual_inversion/autocrop.py341
-rw-r--r--modules/textual_inversion/dataset.py141
-rw-r--r--modules/textual_inversion/image_embedding.py5
-rw-r--r--modules/textual_inversion/learn_schedule.py37
-rw-r--r--modules/textual_inversion/preprocess.py204
-rw-r--r--modules/textual_inversion/textual_inversion.py453
-rw-r--r--modules/textual_inversion/ui.py13
-rw-r--r--modules/txt2img.py20
-rw-r--r--modules/ui.py1306
-rw-r--r--modules/ui_components.py25
-rw-r--r--modules/ui_extensions.py328
-rw-r--r--modules/ui_tempdir.py82
-rw-r--r--modules/upscaler.py31
-rw-r--r--modules/xlmr.py137
73 files changed, 8383 insertions, 5624 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 5b0c934e..48a70a44 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,67 +1,461 @@
-from modules.api.processing import StableDiffusionProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
-from modules.sd_samplers import all_samplers
-from modules.extras import run_pnginfo
-import modules.shared as shared
-import uvicorn
-from fastapi import Body, APIRouter, HTTPException
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel, Field, Json
-import json
-import io
import base64
+import io
+import time
+import datetime
+import uvicorn
+from threading import Lock
+from io import BytesIO
+from gradio.processing_utils import decode_base64_to_file
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from secrets import compare_digest
+
+import modules.shared as shared
+from modules import sd_samplers, deepbooru, sd_hijack
+from modules.api.models import *
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.extras import run_extras, run_pnginfo
+from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
+from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
+from PIL import PngImagePlugin,Image
+from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.realesrgan_model import get_realesrgan_models
+from modules import devices
+from typing import List
-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+def upscaler_to_index(name: str):
+ try:
+ return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
+ except:
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
-class TextToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
+
+def validate_sampler_name(name):
+ config = sd_samplers.all_samplers_map.get(name, None)
+ if config is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+ return name
+
+def setUpscalers(req: dict):
+ reqDict = vars(req)
+ reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
+ reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
+ reqDict.pop('upscaler_1')
+ reqDict.pop('upscaler_2')
+ return reqDict
+
+def decode_base64_to_image(encoding):
+ if encoding.startswith("data:image/"):
+ encoding = encoding.split(";")[1].split(",")[1]
+ return Image.open(BytesIO(base64.b64decode(encoding)))
+
+def encode_pil_to_base64(image):
+ with io.BytesIO() as output_bytes:
+
+ # Copy any text-only metadata
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ image.save(
+ output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
+ )
+ bytes_data = output_bytes.getvalue()
+ return base64.b64encode(bytes_data)
+
+def api_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def log_and_time(req: Request, call_next):
+ ts = time.time()
+ res: Response = await call_next(req)
+ duration = str(round(time.time() - ts, 4))
+ res.headers["X-Process-Time"] = duration
+ endpoint = req.scope.get('path', 'err')
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+ t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code = res.status_code,
+ ver = req.scope.get('http_version', '0.0'),
+ cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot = req.scope.get('scheme', 'err'),
+ method = req.scope.get('method', 'err'),
+ endpoint = endpoint,
+ duration = duration,
+ ))
+ return res
class Api:
- def __init__(self, app, queue_lock):
+ def __init__(self, app: FastAPI, queue_lock: Lock):
+ if shared.cmd_opts.api_auth:
+ self.credentials = dict()
+ for auth in shared.cmd_opts.api_auth.split(","):
+ user, password = auth.split(":")
+ self.credentials[user] = password
+
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
- self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
-
- def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
- sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
- if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
+ api_middleware(self.app)
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
+ self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
+ self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
+
+ def add_api_route(self, path: str, endpoint, **kwargs):
+ if shared.cmd_opts.api_auth:
+ return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
+ return self.app.add_api_route(path, endpoint, **kwargs)
+
+ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
+ if credentials.username in self.credentials:
+ if compare_digest(credentials.password, self.credentials[credentials.username]):
+ return True
+
+ raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
- "sampler_index": sampler_index[0],
+ "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
)
- p = StableDiffusionProcessingTxt2Img(**vars(populate))
- # Override object param
+ if populate.sampler_name:
+ populate.sampler_index = None # prevent a warning later on
+
with self.queue_lock:
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+
+ shared.state.begin()
processed = process_images(p)
-
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
+ shared.state.end()
+
+
+ b64images = list(map(encode_pil_to_base64, processed.images))
+
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ init_images = img2imgreq.init_images
+ if init_images is None:
+ raise HTTPException(status_code=404, detail="Init image not found")
+
+ mask = img2imgreq.mask
+ if mask:
+ mask = decode_base64_to_image(mask)
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
+ "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
+ "do_not_save_samples": True,
+ "do_not_save_grid": True,
+ "mask": mask
+ }
+ )
+ if populate.sampler_name:
+ populate.sampler_index = None # prevent a warning later on
+
+ args = vars(populate)
+ args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+
+ with self.queue_lock:
+ p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+
+ shared.state.begin()
+ processed = process_images(p)
+ shared.state.end()
+
+ b64images = list(map(encode_pil_to_base64, processed.images))
+
+ if not img2imgreq.include_init_images:
+ img2imgreq.init_images = None
+ img2imgreq.mask = None
+
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+
+ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ reqDict = setUpscalers(req)
+
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+
+ def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ reqDict = setUpscalers(req)
+
+ def prepareFiles(file):
+ file = decode_base64_to_file(file.data, file_path=file.name)
+ file.orig_name = file.name
+ return file
+
+ reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
+ reqDict.pop('imageList')
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+
+ def pnginfoapi(self, req: PNGInfoRequest):
+ if(not req.image.strip()):
+ return PNGInfoResponse(info="")
+
+ result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+
+ return PNGInfoResponse(info=result[1])
+
+ def progressapi(self, req: ProgressRequest = Depends()):
+ # copy from check_progress_call of ui.py
+
+ if shared.state.job_count == 0:
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
+
+ # avoid dividing zero
+ progress = 0.01
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ eta_relative = eta-time_since_start
+
+ progress = min(progress, 1)
+
+ shared.state.set_current_image()
+
+ current_image = None
+ if shared.state.current_image and not req.skip_current_image:
+ current_image = encode_pil_to_base64(shared.state.current_image)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+
+ def interrogateapi(self, interrogatereq: InterrogateRequest):
+ image_b64 = interrogatereq.image
+ if image_b64 is None:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ img = decode_base64_to_image(image_b64)
+ img = img.convert('RGB')
+
+ # Override object param
+ with self.queue_lock:
+ if interrogatereq.model == "clip":
+ processed = shared.interrogator.interrogate(img)
+ elif interrogatereq.model == "deepdanbooru":
+ processed = deepbooru.model.tag(img)
+ else:
+ raise HTTPException(status_code=404, detail="Model not found")
+
+ return InterrogateResponse(caption=processed)
+
+ def interruptapi(self):
+ shared.state.interrupt()
+
+ return {}
+
+ def skip(self):
+ shared.state.skip()
+
+ def get_config(self):
+ options = {}
+ for key in shared.opts.data.keys():
+ metadata = shared.opts.data_labels.get(key)
+ if(metadata is not None):
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+ else:
+ options.update({key: shared.opts.data.get(key, None)})
+
+ return options
+
+ def set_config(self, req: Dict[str, Any]):
+ for k, v in req.items():
+ shared.opts.set(k, v)
+
+ shared.opts.save(shared.config_filename)
+ return
+
+ def get_cmd_flags(self):
+ return vars(shared.cmd_opts)
+
+ def get_samplers(self):
+ return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
+
+ def get_upscalers(self):
+ upscalers = []
+
+ for upscaler in shared.sd_upscalers:
+ u = upscaler.scaler
+ upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
+
+ return upscalers
+
+ def get_sd_models(self):
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+
+ def get_hypernetworks(self):
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+
+ def get_face_restorers(self):
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+
+ def get_realesrgan_models(self):
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+
+ def get_prompt_styles(self):
+ styleList = []
+ for k in shared.prompt_styles.styles:
+ style = shared.prompt_styles.styles[k]
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
+
+ return styleList
+
+ def get_artists_categories(self):
+ return shared.artist_db.cats
+
+ def get_artists(self):
+ return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
+
+ def get_embeddings(self):
+ db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+ return {
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
+ }
+
+ def refresh_checkpoints(self):
+ shared.refresh_checkpoints()
+
+ def create_embedding(self, args: dict):
+ try:
+ shared.state.begin()
+ filename = create_embedding(**args) # create empty embedding
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
+ shared.state.end()
+ return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+ except AssertionError as e:
+ shared.state.end()
+ return TrainResponse(info = "create embedding error: {error}".format(error = e))
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
-
-
+ def create_hypernetwork(self, args: dict):
+ try:
+ shared.state.begin()
+ filename = create_hypernetwork(**args) # create empty embedding
+ shared.state.end()
+ return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+ except AssertionError as e:
+ shared.state.end()
+ return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
- def img2imgapi(self):
- raise NotImplementedError
+ def preprocess(self, args: dict):
+ try:
+ shared.state.begin()
+ preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+ shared.state.end()
+ return PreprocessResponse(info = 'preprocess complete')
+ except KeyError as e:
+ shared.state.end()
+ return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+ except AssertionError as e:
+ shared.state.end()
+ return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+ except FileNotFoundError as e:
+ shared.state.end()
+ return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
- def extrasapi(self):
- raise NotImplementedError
+ def train_embedding(self, args: dict):
+ try:
+ shared.state.begin()
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ error = None
+ filename = ''
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ try:
+ embedding, filename = train_embedding(**args) # can take a long time to complete
+ except Exception as e:
+ error = e
+ finally:
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+ shared.state.end()
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ except AssertionError as msg:
+ shared.state.end()
+ return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
- def pnginfoapi(self):
- raise NotImplementedError
+ def train_hypernetwork(self, args: dict):
+ try:
+ shared.state.begin()
+ initial_hypernetwork = shared.loaded_hypernetwork
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ error = None
+ filename = ''
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ try:
+ hypernetwork, filename = train_hypernetwork(*args)
+ except Exception as e:
+ error = e
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+ shared.state.end()
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ except AssertionError as msg:
+ shared.state.end()
+ return TrainResponse(info = "train embedding error: {error}".format(error = error))
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
new file mode 100644
index 00000000..4a632c68
--- /dev/null
+++ b/modules/api/models.py
@@ -0,0 +1,261 @@
+import inspect
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional
+from typing_extensions import Literal
+from inflection import underscore
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
+from modules.shared import sd_upscalers, opts, parser
+from typing import Dict, List
+
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+ field_exclude: bool = False
+
+
+class PydanticModelGenerator:
+ """
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ class_instance = None,
+ additional_fields = None,
+ ):
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
+
+ return Optional[field_type]
+
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
+ self._model_name = model_name
+ self._class_data = merge_class_params(class_instance)
+
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v),
+ field_value=v.default
+ )
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+ ]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"],
+ field_exclude=fields["exclude"] if "exclude" in fields else False))
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
+ return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+).generate_model()
+
+class TextToImageResponse(BaseModel):
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
+ info: str
+
+class ImageToImageResponse(BaseModel):
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
+ info: str
+
+class ExtrasBaseRequest(BaseModel):
+ resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+ show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
+ gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+ codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+ codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+ upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
+ upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+ upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
+
+class ExtraBaseResponse(BaseModel):
+ html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
+
+class ExtrasSingleImageRequest(ExtrasBaseRequest):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class ExtrasSingleImageResponse(ExtraBaseResponse):
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+
+class FileData(BaseModel):
+ data: str = Field(title="File data", description="Base64 representation of the file")
+ name: str = Field(title="File name")
+
+class ExtrasBatchImagesRequest(ExtrasBaseRequest):
+ imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+
+class ExtrasBatchImagesResponse(ExtraBaseResponse):
+ images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+ info: str = Field(title="Image info", description="A string with all the info the image had")
+
+class ProgressRequest(BaseModel):
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
+class ProgressResponse(BaseModel):
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+ eta_relative: float = Field(title="ETA in secs")
+ state: dict = Field(title="State", description="The current state snapshot")
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+
+class InterrogateRequest(BaseModel):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+ model: str = Field(default="clip", title="Model", description="The interrogate model used.")
+
+class InterrogateResponse(BaseModel):
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+
+class TrainResponse(BaseModel):
+ info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+ info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+class PreprocessResponse(BaseModel):
+ info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+
+fields = {}
+for key, metadata in opts.data_labels.items():
+ value = opts.data.get(key)
+ optType = opts.typemap.get(type(metadata.default), type(value))
+
+ if (metadata is not None):
+ fields.update({key: (Optional[optType], Field(
+ default=metadata.default ,description=metadata.label))})
+ else:
+ fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+ if(_options[key].dest != 'help'):
+ flag = _options[key]
+ _type = str
+ if _options[key].default is not None: _type = type(_options[key].default)
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+ name: str = Field(title="Name")
+ aliases: List[str] = Field(title="Aliases")
+ options: Dict[str, str] = Field(title="Options")
+
+class UpscalerItem(BaseModel):
+ name: str = Field(title="Name")
+ model_name: Optional[str] = Field(title="Model Name")
+ model_path: Optional[str] = Field(title="Path")
+ model_url: Optional[str] = Field(title="URL")
+
+class SDModelItem(BaseModel):
+ title: str = Field(title="Title")
+ model_name: str = Field(title="Model Name")
+ hash: str = Field(title="Hash")
+ filename: str = Field(title="Filename")
+ config: str = Field(title="Config file")
+
+class HypernetworkItem(BaseModel):
+ name: str = Field(title="Name")
+ path: Optional[str] = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+ name: str = Field(title="Name")
+ cmd_dir: Optional[str] = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+ name: str = Field(title="Name")
+ path: Optional[str] = Field(title="Path")
+ scale: Optional[int] = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+ name: str = Field(title="Name")
+ prompt: Optional[str] = Field(title="Prompt")
+ negative_prompt: Optional[str] = Field(title="Negative Prompt")
+
+class ArtistItem(BaseModel):
+ name: str = Field(title="Name")
+ score: float = Field(title="Score")
+ category: str = Field(title="Category")
+
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file
diff --git a/modules/api/processing.py b/modules/api/processing.py
deleted file mode 100644
index 4c541241..00000000
--- a/modules/api/processing.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from inflection import underscore
-from typing import Any, Dict, Optional
-from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img
-import inspect
-
-
-API_NOT_ALLOWED = [
- "self",
- "kwargs",
- "sd_model",
- "outpath_samples",
- "outpath_grids",
- "sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
- "extra_generation_params",
- "overlay_images",
- "do_not_reload_embeddings",
- "seed_enable_extras",
- "prompt_for_display",
- "sampler_noise_scheduler_override",
- "ddim_discretize"
-]
-
-class ModelDef(BaseModel):
- """Assistance Class for Pydantic Dynamic Model Generation"""
-
- field: str
- field_alias: str
- field_type: Any
- field_value: Any
-
-
-class PydanticModelGenerator:
- """
- Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
- source_data is a snapshot of the default values produced by the class
- params are the names of the actual keys required by __init__
- """
-
- def __init__(
- self,
- model_name: str = None,
- class_instance = None,
- additional_fields = None,
- ):
- def field_type_generator(k, v):
- # field_type = str if not overrides.get(k) else overrides[k]["type"]
- # print(k, v.annotation, v.default)
- field_type = v.annotation
-
- return Optional[field_type]
-
- def merge_class_params(class_):
- all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
- parameters = {}
- for classes in all_classes:
- parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
- return parameters
-
-
- self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
- self._model_def = [
- ModelDef(
- field=underscore(k),
- field_alias=k,
- field_type=field_type_generator(k, v),
- field_value=v.default
- )
- for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
- ]
-
- for fields in additional_fields:
- self._model_def.append(ModelDef(
- field=underscore(fields["key"]),
- field_alias=fields["key"],
- field_type=fields["type"],
- field_value=fields["default"]))
-
- def generate_model(self):
- """
- Creates a pydantic BaseModel
- from the json and overrides provided at initialization
- """
- fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
- }
- DynamicModel = create_model(self._model_name, **fields)
- DynamicModel.__config__.allow_population_by_field_name = True
- DynamicModel.__config__.allow_mutation = True
- return DynamicModel
-
-StableDiffusionProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingTxt2Img",
- StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
-).generate_model() \ No newline at end of file
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
deleted file mode 100644
index 737e1a76..00000000
--- a/modules/bsrgan_model.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from modules.bsrgan_model_arch import RRDBNet
-
-
-class UpscalerBSRGAN(modules.upscaler.Upscaler):
- def __init__(self, dirname):
- self.name = "BSRGAN"
- self.model_name = "BSRGAN 4x"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pt", ".pth"])
- scalers = []
- if len(model_paths) == 0:
- scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
- scalers.append(scaler_data)
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
- try:
- scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
- scalers.append(scaler_data)
- except Exception:
- print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- self.scalers = scalers
-
- def do_upscale(self, img: PIL.Image, selected_file):
- torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- return img
- model.to(devices.device_bsrgan)
- torch.cuda.empty_cache()
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_bsrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
-
- def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
- else:
- filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
- return None
- model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for k, v in model.named_parameters():
- v.requires_grad = False
- return model
-
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
deleted file mode 100644
index cb4d1c13..00000000
--- a/modules/bsrgan_model_arch.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.nn.init as init
-
-
-def initialize_weights(net_l, scale=1):
- if not isinstance(net_l, list):
- net_l = [net_l]
- for net in net_l:
- for m in net.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale # for residual block
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- init.constant_(m.weight, 1)
- init.constant_(m.bias.data, 0.0)
-
-
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- # initialization
- initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
-
- def __init__(self, nf, gc=32):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
- self.sf = sf
-
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- if self.sf==4:
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
-
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- if self.sf==4:
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
- return out \ No newline at end of file
diff --git a/modules/call_queue.py b/modules/call_queue.py
new file mode 100644
index 00000000..4cd49533
--- /dev/null
+++ b/modules/call_queue.py
@@ -0,0 +1,98 @@
+import html
+import sys
+import threading
+import traceback
+import time
+
+from modules import shared
+
+queue_lock = threading.Lock()
+
+
+def wrap_queued_call(func):
+ def f(*args, **kwargs):
+ with queue_lock:
+ res = func(*args, **kwargs)
+
+ return res
+
+ return f
+
+
+def wrap_gradio_gpu_call(func, extra_outputs=None):
+ def f(*args, **kwargs):
+
+ shared.state.begin()
+
+ with queue_lock:
+ res = func(*args, **kwargs)
+
+ shared.state.end()
+
+ return res
+
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
+
+
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
+ run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
+ if run_memmon:
+ shared.mem_mon.monitor()
+ t = time.perf_counter()
+
+ try:
+ res = list(func(*args, **kwargs))
+ except Exception as e:
+ # When printing out our debug argument list, do not print out more than a MB of text
+ max_debug_str_len = 131072 # (1024*1024)/8
+
+ print("Error completing request", file=sys.stderr)
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
+ print(argStr[:max_debug_str_len], file=sys.stderr)
+ if len(argStr) > max_debug_str_len:
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.state.job = ""
+ shared.state.job_count = 0
+
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.job_count = 0
+
+ if not add_stats:
+ return tuple(res)
+
+ elapsed = time.perf_counter() - t
+ elapsed_m = int(elapsed // 60)
+ elapsed_s = elapsed % 60
+ elapsed_text = f"{elapsed_s:.2f}s"
+ if elapsed_m > 0:
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
+
+ if run_memmon:
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+ active_peak = mem_stats['active_peak']
+ reserved_peak = mem_stats['reserved_peak']
+ sys_peak = mem_stats['system_peak']
+ sys_total = mem_stats['total']
+ sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+
+ vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+ else:
+ vram_html = ''
+
+ # last item is always HTML
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+
+ return tuple(res)
+
+ return f
+
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index c06c590c..e7293683 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
- raise ValueError(f'Wrong params!')
+ raise ValueError('Wrong params!')
def forward(self, x):
@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
- raise ValueError(f'Wrong params!')
+ raise ValueError('Wrong params!')
def forward(self, x):
return self.main(x) \ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index e6d9fa4f..ab40d842 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -36,6 +36,7 @@ def setup_model(dirname):
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
+ from facelib.detection.retinaface import retinaface
from modules.shared import cmd_opts
net_class = CodeFormer
@@ -65,6 +66,8 @@ def setup_model(dirname):
net.load_state_dict(checkpoint)
net.eval()
+ if hasattr(retinaface, 'device'):
+ retinaface.device = devices.device_codeformer
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
self.net = net
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 8914662d..122fce7f 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -1,172 +1,99 @@
-import os.path
-from concurrent.futures import ProcessPoolExecutor
-import multiprocessing
-import time
+import os
import re
+import torch
+from PIL import Image
+import numpy as np
+
+from modules import modelloader, paths, deepbooru_model, devices, images, shared
+
re_special = re.compile(r'([\\()])')
-def get_deepbooru_tags(pil_image):
- """
- This method is for running only one image at a time for simple use. Used to the img2img interrogate.
- """
- from modules import shared # prevents circular reference
-
- try:
- create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
- return get_tags_from_process(pil_image)
- finally:
- release_process()
-
-
-OPT_INCLUDE_RANKS = "include_ranks"
-def create_deepbooru_opts():
- from modules import shared
-
- return {
- "use_spaces": shared.opts.deepbooru_use_spaces,
- "use_escape": shared.opts.deepbooru_escape,
- "alpha_sort": shared.opts.deepbooru_sort_alpha,
- OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
- }
-
-
-def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
- model, tags = get_deepbooru_tags_model()
- while True: # while process is running, keep monitoring queue for new image
- pil_image = queue.get()
- if pil_image == "QUIT":
- break
- else:
- deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
-
-
-def create_deepbooru_process(threshold, deepbooru_opts):
- """
- Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
- to be processed in a row without reloading the model or creating a new process. To return the data, a shared
- dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
- to the dictionary and the method adding the image to the queue should wait for this value to be updated with
- the tags.
- """
- from modules import shared # prevents circular reference
- shared.deepbooru_process_manager = multiprocessing.Manager()
- shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
- shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
- shared.deepbooru_process.start()
-
-
-def get_tags_from_process(image):
- from modules import shared
-
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process_queue.put(image)
- while shared.deepbooru_process_return["value"] == -1:
- time.sleep(0.2)
- caption = shared.deepbooru_process_return["value"]
- shared.deepbooru_process_return["value"] = -1
-
- return caption
-
-
-def release_process():
- """
- Stops the deepbooru process to return used memory
- """
- from modules import shared # prevents circular reference
- shared.deepbooru_process_queue.put("QUIT")
- shared.deepbooru_process.join()
- shared.deepbooru_process_queue = None
- shared.deepbooru_process = None
- shared.deepbooru_process_return = None
- shared.deepbooru_process_manager = None
-
-def get_deepbooru_tags_model():
- import deepdanbooru as dd
- import tensorflow as tf
- import numpy as np
- this_folder = os.path.dirname(__file__)
- model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
- if not os.path.exists(os.path.join(model_path, 'project.json')):
- # there is no point importing these every time
- import zipfile
- from basicsr.utils.download_util import load_file_from_url
- load_file_from_url(
- r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
- model_path)
- with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
- zip_ref.extractall(model_path)
- os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
-
- tags = dd.project.load_tags_from_project(model_path)
- model = dd.project.load_model_from_project(
- model_path, compile_model=False
- )
- return model, tags
-
-
-def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
- import deepdanbooru as dd
- import tensorflow as tf
- import numpy as np
-
- alpha_sort = deepbooru_opts['alpha_sort']
- use_spaces = deepbooru_opts['use_spaces']
- use_escape = deepbooru_opts['use_escape']
- include_ranks = deepbooru_opts['include_ranks']
-
- width = model.input_shape[2]
- height = model.input_shape[1]
- image = np.array(pil_image)
- image = tf.image.resize(
- image,
- size=(height, width),
- method=tf.image.ResizeMethod.AREA,
- preserve_aspect_ratio=True,
- )
- image = image.numpy() # EagerTensor to np.array
- image = dd.image.transform_and_pad_image(image, width, height)
- image = image / 255.0
- image_shape = image.shape
- image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
-
- y = model.predict(image)[0]
-
- result_dict = {}
-
- for i, tag in enumerate(tags):
- result_dict[tag] = y[i]
-
- unsorted_tags_in_theshold = []
- result_tags_print = []
- for tag in tags:
- if result_dict[tag] >= threshold:
+
+class DeepDanbooru:
+ def __init__(self):
+ self.model = None
+
+ def load(self):
+ if self.model is not None:
+ return
+
+ files = modelloader.load_models(
+ model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
+ model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
+ ext_filter=[".pt"],
+ download_name='model-resnet_custom_v3.pt',
+ )
+
+ self.model = deepbooru_model.DeepDanbooruModel()
+ self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
+
+ self.model.eval()
+ self.model.to(devices.cpu, devices.dtype)
+
+ def start(self):
+ self.load()
+ self.model.to(devices.device)
+
+ def stop(self):
+ if not shared.opts.interrogate_keep_models_in_memory:
+ self.model.to(devices.cpu)
+ devices.torch_gc()
+
+ def tag(self, pil_image):
+ self.start()
+ res = self.tag_multi(pil_image)
+ self.stop()
+
+ return res
+
+ def tag_multi(self, pil_image, force_disable_ranks=False):
+ threshold = shared.opts.interrogate_deepbooru_score_threshold
+ use_spaces = shared.opts.deepbooru_use_spaces
+ use_escape = shared.opts.deepbooru_escape
+ alpha_sort = shared.opts.deepbooru_sort_alpha
+ include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
+
+ pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
+ a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
+
+ with torch.no_grad(), devices.autocast():
+ x = torch.from_numpy(a).to(devices.device)
+ y = self.model(x)[0].detach().cpu().numpy()
+
+ probability_dict = {}
+
+ for tag, probability in zip(self.model.tags, y):
+ if probability < threshold:
+ continue
+
if tag.startswith("rating:"):
continue
- unsorted_tags_in_theshold.append((result_dict[tag], tag))
- result_tags_print.append(f'{result_dict[tag]} {tag}')
-
- # sort tags
- result_tags_out = []
- sort_ndx = 0
- if alpha_sort:
- sort_ndx = 1
-
- # sort by reverse by likelihood and normal for alpha, and format tag text as requested
- unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
- for weight, tag in unsorted_tags_in_theshold:
- tag_outformat = tag
- if use_spaces:
- tag_outformat = tag_outformat.replace('_', ' ')
- if use_escape:
- tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
- if include_ranks:
- tag_outformat = f"({tag_outformat}:{weight:.3f})"
-
- result_tags_out.append(tag_outformat)
-
- print('\n'.join(sorted(result_tags_print, reverse=True)))
-
- return ', '.join(result_tags_out)
+
+ probability_dict[tag] = probability
+
+ if alpha_sort:
+ tags = sorted(probability_dict)
+ else:
+ tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
+
+ res = []
+
+ filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+
+ for tag in [x for x in tags if x not in filtertags]:
+ probability = probability_dict[tag]
+ tag_outformat = tag
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{probability:.3f})"
+
+ res.append(tag_outformat)
+
+ return ", ".join(res)
+
+
+model = DeepDanbooru()
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
new file mode 100644
index 00000000..edd40c81
--- /dev/null
+++ b/modules/deepbooru_model.py
@@ -0,0 +1,676 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
+
+
+class DeepDanbooruModel(nn.Module):
+ def __init__(self):
+ super(DeepDanbooruModel, self).__init__()
+
+ self.tags = []
+
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
+
+ def forward(self, *inputs):
+ t_358, = inputs
+ t_359 = t_358.permute(*[0, 3, 1, 2])
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
+ t_360 = self.n_Conv_0(t_359_padded)
+ t_361 = F.relu(t_360)
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
+ t_362 = self.n_MaxPool_0(t_361)
+ t_363 = self.n_Conv_1(t_362)
+ t_364 = self.n_Conv_2(t_362)
+ t_365 = F.relu(t_364)
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
+ t_366 = self.n_Conv_3(t_365_padded)
+ t_367 = F.relu(t_366)
+ t_368 = self.n_Conv_4(t_367)
+ t_369 = torch.add(t_368, t_363)
+ t_370 = F.relu(t_369)
+ t_371 = self.n_Conv_5(t_370)
+ t_372 = F.relu(t_371)
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
+ t_373 = self.n_Conv_6(t_372_padded)
+ t_374 = F.relu(t_373)
+ t_375 = self.n_Conv_7(t_374)
+ t_376 = torch.add(t_375, t_370)
+ t_377 = F.relu(t_376)
+ t_378 = self.n_Conv_8(t_377)
+ t_379 = F.relu(t_378)
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
+ t_380 = self.n_Conv_9(t_379_padded)
+ t_381 = F.relu(t_380)
+ t_382 = self.n_Conv_10(t_381)
+ t_383 = torch.add(t_382, t_377)
+ t_384 = F.relu(t_383)
+ t_385 = self.n_Conv_11(t_384)
+ t_386 = self.n_Conv_12(t_384)
+ t_387 = F.relu(t_386)
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
+ t_388 = self.n_Conv_13(t_387_padded)
+ t_389 = F.relu(t_388)
+ t_390 = self.n_Conv_14(t_389)
+ t_391 = torch.add(t_390, t_385)
+ t_392 = F.relu(t_391)
+ t_393 = self.n_Conv_15(t_392)
+ t_394 = F.relu(t_393)
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
+ t_395 = self.n_Conv_16(t_394_padded)
+ t_396 = F.relu(t_395)
+ t_397 = self.n_Conv_17(t_396)
+ t_398 = torch.add(t_397, t_392)
+ t_399 = F.relu(t_398)
+ t_400 = self.n_Conv_18(t_399)
+ t_401 = F.relu(t_400)
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
+ t_402 = self.n_Conv_19(t_401_padded)
+ t_403 = F.relu(t_402)
+ t_404 = self.n_Conv_20(t_403)
+ t_405 = torch.add(t_404, t_399)
+ t_406 = F.relu(t_405)
+ t_407 = self.n_Conv_21(t_406)
+ t_408 = F.relu(t_407)
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
+ t_409 = self.n_Conv_22(t_408_padded)
+ t_410 = F.relu(t_409)
+ t_411 = self.n_Conv_23(t_410)
+ t_412 = torch.add(t_411, t_406)
+ t_413 = F.relu(t_412)
+ t_414 = self.n_Conv_24(t_413)
+ t_415 = F.relu(t_414)
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
+ t_416 = self.n_Conv_25(t_415_padded)
+ t_417 = F.relu(t_416)
+ t_418 = self.n_Conv_26(t_417)
+ t_419 = torch.add(t_418, t_413)
+ t_420 = F.relu(t_419)
+ t_421 = self.n_Conv_27(t_420)
+ t_422 = F.relu(t_421)
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
+ t_423 = self.n_Conv_28(t_422_padded)
+ t_424 = F.relu(t_423)
+ t_425 = self.n_Conv_29(t_424)
+ t_426 = torch.add(t_425, t_420)
+ t_427 = F.relu(t_426)
+ t_428 = self.n_Conv_30(t_427)
+ t_429 = F.relu(t_428)
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
+ t_430 = self.n_Conv_31(t_429_padded)
+ t_431 = F.relu(t_430)
+ t_432 = self.n_Conv_32(t_431)
+ t_433 = torch.add(t_432, t_427)
+ t_434 = F.relu(t_433)
+ t_435 = self.n_Conv_33(t_434)
+ t_436 = F.relu(t_435)
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
+ t_437 = self.n_Conv_34(t_436_padded)
+ t_438 = F.relu(t_437)
+ t_439 = self.n_Conv_35(t_438)
+ t_440 = torch.add(t_439, t_434)
+ t_441 = F.relu(t_440)
+ t_442 = self.n_Conv_36(t_441)
+ t_443 = self.n_Conv_37(t_441)
+ t_444 = F.relu(t_443)
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
+ t_445 = self.n_Conv_38(t_444_padded)
+ t_446 = F.relu(t_445)
+ t_447 = self.n_Conv_39(t_446)
+ t_448 = torch.add(t_447, t_442)
+ t_449 = F.relu(t_448)
+ t_450 = self.n_Conv_40(t_449)
+ t_451 = F.relu(t_450)
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
+ t_452 = self.n_Conv_41(t_451_padded)
+ t_453 = F.relu(t_452)
+ t_454 = self.n_Conv_42(t_453)
+ t_455 = torch.add(t_454, t_449)
+ t_456 = F.relu(t_455)
+ t_457 = self.n_Conv_43(t_456)
+ t_458 = F.relu(t_457)
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
+ t_459 = self.n_Conv_44(t_458_padded)
+ t_460 = F.relu(t_459)
+ t_461 = self.n_Conv_45(t_460)
+ t_462 = torch.add(t_461, t_456)
+ t_463 = F.relu(t_462)
+ t_464 = self.n_Conv_46(t_463)
+ t_465 = F.relu(t_464)
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
+ t_466 = self.n_Conv_47(t_465_padded)
+ t_467 = F.relu(t_466)
+ t_468 = self.n_Conv_48(t_467)
+ t_469 = torch.add(t_468, t_463)
+ t_470 = F.relu(t_469)
+ t_471 = self.n_Conv_49(t_470)
+ t_472 = F.relu(t_471)
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
+ t_473 = self.n_Conv_50(t_472_padded)
+ t_474 = F.relu(t_473)
+ t_475 = self.n_Conv_51(t_474)
+ t_476 = torch.add(t_475, t_470)
+ t_477 = F.relu(t_476)
+ t_478 = self.n_Conv_52(t_477)
+ t_479 = F.relu(t_478)
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
+ t_480 = self.n_Conv_53(t_479_padded)
+ t_481 = F.relu(t_480)
+ t_482 = self.n_Conv_54(t_481)
+ t_483 = torch.add(t_482, t_477)
+ t_484 = F.relu(t_483)
+ t_485 = self.n_Conv_55(t_484)
+ t_486 = F.relu(t_485)
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
+ t_487 = self.n_Conv_56(t_486_padded)
+ t_488 = F.relu(t_487)
+ t_489 = self.n_Conv_57(t_488)
+ t_490 = torch.add(t_489, t_484)
+ t_491 = F.relu(t_490)
+ t_492 = self.n_Conv_58(t_491)
+ t_493 = F.relu(t_492)
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
+ t_494 = self.n_Conv_59(t_493_padded)
+ t_495 = F.relu(t_494)
+ t_496 = self.n_Conv_60(t_495)
+ t_497 = torch.add(t_496, t_491)
+ t_498 = F.relu(t_497)
+ t_499 = self.n_Conv_61(t_498)
+ t_500 = F.relu(t_499)
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
+ t_501 = self.n_Conv_62(t_500_padded)
+ t_502 = F.relu(t_501)
+ t_503 = self.n_Conv_63(t_502)
+ t_504 = torch.add(t_503, t_498)
+ t_505 = F.relu(t_504)
+ t_506 = self.n_Conv_64(t_505)
+ t_507 = F.relu(t_506)
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
+ t_508 = self.n_Conv_65(t_507_padded)
+ t_509 = F.relu(t_508)
+ t_510 = self.n_Conv_66(t_509)
+ t_511 = torch.add(t_510, t_505)
+ t_512 = F.relu(t_511)
+ t_513 = self.n_Conv_67(t_512)
+ t_514 = F.relu(t_513)
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
+ t_515 = self.n_Conv_68(t_514_padded)
+ t_516 = F.relu(t_515)
+ t_517 = self.n_Conv_69(t_516)
+ t_518 = torch.add(t_517, t_512)
+ t_519 = F.relu(t_518)
+ t_520 = self.n_Conv_70(t_519)
+ t_521 = F.relu(t_520)
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
+ t_522 = self.n_Conv_71(t_521_padded)
+ t_523 = F.relu(t_522)
+ t_524 = self.n_Conv_72(t_523)
+ t_525 = torch.add(t_524, t_519)
+ t_526 = F.relu(t_525)
+ t_527 = self.n_Conv_73(t_526)
+ t_528 = F.relu(t_527)
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
+ t_529 = self.n_Conv_74(t_528_padded)
+ t_530 = F.relu(t_529)
+ t_531 = self.n_Conv_75(t_530)
+ t_532 = torch.add(t_531, t_526)
+ t_533 = F.relu(t_532)
+ t_534 = self.n_Conv_76(t_533)
+ t_535 = F.relu(t_534)
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
+ t_536 = self.n_Conv_77(t_535_padded)
+ t_537 = F.relu(t_536)
+ t_538 = self.n_Conv_78(t_537)
+ t_539 = torch.add(t_538, t_533)
+ t_540 = F.relu(t_539)
+ t_541 = self.n_Conv_79(t_540)
+ t_542 = F.relu(t_541)
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
+ t_543 = self.n_Conv_80(t_542_padded)
+ t_544 = F.relu(t_543)
+ t_545 = self.n_Conv_81(t_544)
+ t_546 = torch.add(t_545, t_540)
+ t_547 = F.relu(t_546)
+ t_548 = self.n_Conv_82(t_547)
+ t_549 = F.relu(t_548)
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
+ t_550 = self.n_Conv_83(t_549_padded)
+ t_551 = F.relu(t_550)
+ t_552 = self.n_Conv_84(t_551)
+ t_553 = torch.add(t_552, t_547)
+ t_554 = F.relu(t_553)
+ t_555 = self.n_Conv_85(t_554)
+ t_556 = F.relu(t_555)
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
+ t_557 = self.n_Conv_86(t_556_padded)
+ t_558 = F.relu(t_557)
+ t_559 = self.n_Conv_87(t_558)
+ t_560 = torch.add(t_559, t_554)
+ t_561 = F.relu(t_560)
+ t_562 = self.n_Conv_88(t_561)
+ t_563 = F.relu(t_562)
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
+ t_564 = self.n_Conv_89(t_563_padded)
+ t_565 = F.relu(t_564)
+ t_566 = self.n_Conv_90(t_565)
+ t_567 = torch.add(t_566, t_561)
+ t_568 = F.relu(t_567)
+ t_569 = self.n_Conv_91(t_568)
+ t_570 = F.relu(t_569)
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
+ t_571 = self.n_Conv_92(t_570_padded)
+ t_572 = F.relu(t_571)
+ t_573 = self.n_Conv_93(t_572)
+ t_574 = torch.add(t_573, t_568)
+ t_575 = F.relu(t_574)
+ t_576 = self.n_Conv_94(t_575)
+ t_577 = F.relu(t_576)
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
+ t_578 = self.n_Conv_95(t_577_padded)
+ t_579 = F.relu(t_578)
+ t_580 = self.n_Conv_96(t_579)
+ t_581 = torch.add(t_580, t_575)
+ t_582 = F.relu(t_581)
+ t_583 = self.n_Conv_97(t_582)
+ t_584 = F.relu(t_583)
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
+ t_585 = self.n_Conv_98(t_584_padded)
+ t_586 = F.relu(t_585)
+ t_587 = self.n_Conv_99(t_586)
+ t_588 = self.n_Conv_100(t_582)
+ t_589 = torch.add(t_587, t_588)
+ t_590 = F.relu(t_589)
+ t_591 = self.n_Conv_101(t_590)
+ t_592 = F.relu(t_591)
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
+ t_593 = self.n_Conv_102(t_592_padded)
+ t_594 = F.relu(t_593)
+ t_595 = self.n_Conv_103(t_594)
+ t_596 = torch.add(t_595, t_590)
+ t_597 = F.relu(t_596)
+ t_598 = self.n_Conv_104(t_597)
+ t_599 = F.relu(t_598)
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
+ t_600 = self.n_Conv_105(t_599_padded)
+ t_601 = F.relu(t_600)
+ t_602 = self.n_Conv_106(t_601)
+ t_603 = torch.add(t_602, t_597)
+ t_604 = F.relu(t_603)
+ t_605 = self.n_Conv_107(t_604)
+ t_606 = F.relu(t_605)
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
+ t_607 = self.n_Conv_108(t_606_padded)
+ t_608 = F.relu(t_607)
+ t_609 = self.n_Conv_109(t_608)
+ t_610 = torch.add(t_609, t_604)
+ t_611 = F.relu(t_610)
+ t_612 = self.n_Conv_110(t_611)
+ t_613 = F.relu(t_612)
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
+ t_614 = self.n_Conv_111(t_613_padded)
+ t_615 = F.relu(t_614)
+ t_616 = self.n_Conv_112(t_615)
+ t_617 = torch.add(t_616, t_611)
+ t_618 = F.relu(t_617)
+ t_619 = self.n_Conv_113(t_618)
+ t_620 = F.relu(t_619)
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
+ t_621 = self.n_Conv_114(t_620_padded)
+ t_622 = F.relu(t_621)
+ t_623 = self.n_Conv_115(t_622)
+ t_624 = torch.add(t_623, t_618)
+ t_625 = F.relu(t_624)
+ t_626 = self.n_Conv_116(t_625)
+ t_627 = F.relu(t_626)
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
+ t_628 = self.n_Conv_117(t_627_padded)
+ t_629 = F.relu(t_628)
+ t_630 = self.n_Conv_118(t_629)
+ t_631 = torch.add(t_630, t_625)
+ t_632 = F.relu(t_631)
+ t_633 = self.n_Conv_119(t_632)
+ t_634 = F.relu(t_633)
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
+ t_635 = self.n_Conv_120(t_634_padded)
+ t_636 = F.relu(t_635)
+ t_637 = self.n_Conv_121(t_636)
+ t_638 = torch.add(t_637, t_632)
+ t_639 = F.relu(t_638)
+ t_640 = self.n_Conv_122(t_639)
+ t_641 = F.relu(t_640)
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
+ t_642 = self.n_Conv_123(t_641_padded)
+ t_643 = F.relu(t_642)
+ t_644 = self.n_Conv_124(t_643)
+ t_645 = torch.add(t_644, t_639)
+ t_646 = F.relu(t_645)
+ t_647 = self.n_Conv_125(t_646)
+ t_648 = F.relu(t_647)
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
+ t_649 = self.n_Conv_126(t_648_padded)
+ t_650 = F.relu(t_649)
+ t_651 = self.n_Conv_127(t_650)
+ t_652 = torch.add(t_651, t_646)
+ t_653 = F.relu(t_652)
+ t_654 = self.n_Conv_128(t_653)
+ t_655 = F.relu(t_654)
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
+ t_656 = self.n_Conv_129(t_655_padded)
+ t_657 = F.relu(t_656)
+ t_658 = self.n_Conv_130(t_657)
+ t_659 = torch.add(t_658, t_653)
+ t_660 = F.relu(t_659)
+ t_661 = self.n_Conv_131(t_660)
+ t_662 = F.relu(t_661)
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
+ t_663 = self.n_Conv_132(t_662_padded)
+ t_664 = F.relu(t_663)
+ t_665 = self.n_Conv_133(t_664)
+ t_666 = torch.add(t_665, t_660)
+ t_667 = F.relu(t_666)
+ t_668 = self.n_Conv_134(t_667)
+ t_669 = F.relu(t_668)
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
+ t_670 = self.n_Conv_135(t_669_padded)
+ t_671 = F.relu(t_670)
+ t_672 = self.n_Conv_136(t_671)
+ t_673 = torch.add(t_672, t_667)
+ t_674 = F.relu(t_673)
+ t_675 = self.n_Conv_137(t_674)
+ t_676 = F.relu(t_675)
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
+ t_677 = self.n_Conv_138(t_676_padded)
+ t_678 = F.relu(t_677)
+ t_679 = self.n_Conv_139(t_678)
+ t_680 = torch.add(t_679, t_674)
+ t_681 = F.relu(t_680)
+ t_682 = self.n_Conv_140(t_681)
+ t_683 = F.relu(t_682)
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
+ t_684 = self.n_Conv_141(t_683_padded)
+ t_685 = F.relu(t_684)
+ t_686 = self.n_Conv_142(t_685)
+ t_687 = torch.add(t_686, t_681)
+ t_688 = F.relu(t_687)
+ t_689 = self.n_Conv_143(t_688)
+ t_690 = F.relu(t_689)
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
+ t_691 = self.n_Conv_144(t_690_padded)
+ t_692 = F.relu(t_691)
+ t_693 = self.n_Conv_145(t_692)
+ t_694 = torch.add(t_693, t_688)
+ t_695 = F.relu(t_694)
+ t_696 = self.n_Conv_146(t_695)
+ t_697 = F.relu(t_696)
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
+ t_698 = self.n_Conv_147(t_697_padded)
+ t_699 = F.relu(t_698)
+ t_700 = self.n_Conv_148(t_699)
+ t_701 = torch.add(t_700, t_695)
+ t_702 = F.relu(t_701)
+ t_703 = self.n_Conv_149(t_702)
+ t_704 = F.relu(t_703)
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
+ t_705 = self.n_Conv_150(t_704_padded)
+ t_706 = F.relu(t_705)
+ t_707 = self.n_Conv_151(t_706)
+ t_708 = torch.add(t_707, t_702)
+ t_709 = F.relu(t_708)
+ t_710 = self.n_Conv_152(t_709)
+ t_711 = F.relu(t_710)
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
+ t_712 = self.n_Conv_153(t_711_padded)
+ t_713 = F.relu(t_712)
+ t_714 = self.n_Conv_154(t_713)
+ t_715 = torch.add(t_714, t_709)
+ t_716 = F.relu(t_715)
+ t_717 = self.n_Conv_155(t_716)
+ t_718 = F.relu(t_717)
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
+ t_719 = self.n_Conv_156(t_718_padded)
+ t_720 = F.relu(t_719)
+ t_721 = self.n_Conv_157(t_720)
+ t_722 = torch.add(t_721, t_716)
+ t_723 = F.relu(t_722)
+ t_724 = self.n_Conv_158(t_723)
+ t_725 = self.n_Conv_159(t_723)
+ t_726 = F.relu(t_725)
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
+ t_727 = self.n_Conv_160(t_726_padded)
+ t_728 = F.relu(t_727)
+ t_729 = self.n_Conv_161(t_728)
+ t_730 = torch.add(t_729, t_724)
+ t_731 = F.relu(t_730)
+ t_732 = self.n_Conv_162(t_731)
+ t_733 = F.relu(t_732)
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
+ t_734 = self.n_Conv_163(t_733_padded)
+ t_735 = F.relu(t_734)
+ t_736 = self.n_Conv_164(t_735)
+ t_737 = torch.add(t_736, t_731)
+ t_738 = F.relu(t_737)
+ t_739 = self.n_Conv_165(t_738)
+ t_740 = F.relu(t_739)
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
+ t_741 = self.n_Conv_166(t_740_padded)
+ t_742 = F.relu(t_741)
+ t_743 = self.n_Conv_167(t_742)
+ t_744 = torch.add(t_743, t_738)
+ t_745 = F.relu(t_744)
+ t_746 = self.n_Conv_168(t_745)
+ t_747 = self.n_Conv_169(t_745)
+ t_748 = F.relu(t_747)
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
+ t_749 = self.n_Conv_170(t_748_padded)
+ t_750 = F.relu(t_749)
+ t_751 = self.n_Conv_171(t_750)
+ t_752 = torch.add(t_751, t_746)
+ t_753 = F.relu(t_752)
+ t_754 = self.n_Conv_172(t_753)
+ t_755 = F.relu(t_754)
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
+ t_756 = self.n_Conv_173(t_755_padded)
+ t_757 = F.relu(t_756)
+ t_758 = self.n_Conv_174(t_757)
+ t_759 = torch.add(t_758, t_753)
+ t_760 = F.relu(t_759)
+ t_761 = self.n_Conv_175(t_760)
+ t_762 = F.relu(t_761)
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
+ t_763 = self.n_Conv_176(t_762_padded)
+ t_764 = F.relu(t_763)
+ t_765 = self.n_Conv_177(t_764)
+ t_766 = torch.add(t_765, t_760)
+ t_767 = F.relu(t_766)
+ t_768 = self.n_Conv_178(t_767)
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
+ t_770 = torch.squeeze(t_769, 3)
+ t_770 = torch.squeeze(t_770, 2)
+ t_771 = torch.sigmoid(t_770)
+ return t_771
+
+ def load_state_dict(self, state_dict, **kwargs):
+ self.tags = state_dict.get('tags', [])
+
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
+
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..800510b7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,62 +1,96 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
+from packaging import version
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def has_mps() -> bool:
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+
+
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]:
+ return args[x + 1]
+
+ return None
+
+
+def get_cuda_device_string():
+ from modules import shared
+
+ if shared.cmd_opts.device_id is not None:
+ return f"cuda:{shared.cmd_opts.device_id}"
+
+ return "cuda"
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ return torch.device(get_cuda_device_string())
- if has_mps:
+ if has_mps():
return torch.device("mps")
return cpu
+def get_device_for(task):
+ from modules import shared
+
+ if task in shared.cmd_opts.use_cpu:
+ return cpu
+
+ return get_optimal_device()
+
+
def torch_gc():
if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+ with torch.cuda.device(get_cuda_device_string()):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
def enable_tf32():
if torch.cuda.is_available():
+
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+ if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ torch.backends.cudnn.benchmark = True
+
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
+
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
-def randn(seed, shape):
- # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
- if device.type == 'mps':
- generator = torch.Generator(device=cpu)
- generator.manual_seed(seed)
- noise = torch.randn(shape, generator=generator, device=cpu).to(device)
- return noise
+def randn(seed, shape):
torch.manual_seed(seed)
+ if device.type == 'mps':
+ return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
- # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
- generator = torch.Generator(device=cpu)
- noise = torch.randn(shape, generator=generator, device=cpu).to(device)
- return noise
-
+ return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
@@ -70,3 +104,37 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+orig_tensor_to = torch.Tensor.to
+def tensor_to_fix(self, *args, **kwargs):
+ if self.device.type != 'mps' and \
+ ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
+ (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
+ self = self.contiguous()
+ return orig_tensor_to(self, *args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+orig_layer_norm = torch.nn.functional.layer_norm
+def layer_norm_fix(*args, **kwargs):
+ if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
+ args = list(args)
+ args[0] = args[0].contiguous()
+ return orig_layer_norm(*args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+orig_tensor_numpy = torch.Tensor.numpy
+def numpy_fix(self, *args, **kwargs):
+ if self.requires_grad:
+ self = self.detach()
+ return orig_tensor_numpy(self, *args, **kwargs)
+
+
+# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
+ torch.Tensor.to = tensor_to_fix
+ torch.nn.functional.layer_norm = layer_norm_fix
+ torch.Tensor.numpy = numpy_fix
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
import traceback
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
def run(code, task):
try:
code()
except Exception as e:
- print(f"{task}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ display(task, e)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 46ad0da3..9a9c38f1 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -11,62 +11,118 @@ from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-def fix_model_layers(crt_model, pretrained_net):
- # this code is adapted from https://github.com/xinntao/ESRGAN
- if 'conv_first.weight' in pretrained_net:
- return pretrained_net
-
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
- else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- return crt_net
+def mod2normal(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if 'conv_first.weight' in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def resrgan2normal(state_dict, nb=23):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ re8x = 0
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if "rdb" in k:
+ ori_k = k.replace('body.', 'model.1.sub.')
+ ori_k = ori_k.replace('.rdb', '.RDB')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+
+ if 'conv_up3.weight' in state_dict:
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+ re8x = 3
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
+ state_dict = crt_net
+ return state_dict
+
+
+def infer_params(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ scale2x = 0
+ scalemin = 6
+ n_uplayer = 0
+ plus = False
+
+ for block in list(state_dict):
+ parts = block.split(".")
+ n_parts = len(parts)
+ if n_parts == 5 and parts[2] == "sub":
+ nb = int(parts[3])
+ elif n_parts == 3:
+ part_num = int(parts[1])
+ if (part_num > scalemin
+ and parts[0] == "model"
+ and parts[2] == "weight"):
+ scale2x += 1
+ if part_num > n_uplayer:
+ n_uplayer = part_num
+ out_nc = state_dict[block].shape[0]
+ if not plus and "conv1x1" in block:
+ plus = True
+
+ nf = state_dict["model.0.weight"].shape[0]
+ in_nc = state_dict["model.0.weight"].shape[1]
+ out_nc = out_nc
+ scale = 2 ** scale2x
+
+ return in_nc, out_nc, nf, nb, plus, scale
+
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
@@ -109,20 +165,39 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
+
+ if "params_ema" in state_dict:
+ state_dict = state_dict["params_ema"]
+ elif "params" in state_dict:
+ state_dict = state_dict["params"]
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
+ state_dict = resrgan2normal(state_dict, nb)
+ elif "conv_first.weight" in state_dict:
+ state_dict = mod2normal(state_dict)
+ elif "model.0.weight" not in state_dict:
+ raise Exception("The file is not a recognized ESRGAN model.")
+
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
- pretrained_net = fix_model_layers(crt_model, pretrained_net)
- crt_model.load_state_dict(pretrained_net)
- crt_model.eval()
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+ model.load_state_dict(state_dict)
+ model.eval()
- return crt_model
+ return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index e413d36e..bc9ceb2a 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -1,80 +1,463 @@
-# this file is taken from https://github.com/xinntao/ESRGAN
+# this file is adapted from https://github.com/victorca25/iNNfer
+import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
+####################
+# RRDBNet Generator
+####################
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+ finalact=None, gaussian_noise=False, plus=False):
+ super(RRDBNet, self).__init__()
+ n_upscale = int(math.log(upscale, 2))
+ if upscale == 3:
+ n_upscale = 1
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ self.resrgan_scale = 0
+ if in_nc % 16 == 0:
+ self.resrgan_scale = 1
+ elif in_nc != 4 and in_nc % 4 == 0:
+ self.resrgan_scale = 2
- # initialization
- # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
+ if upsample_mode == 'upconv':
+ upsample_block = upconv_block
+ elif upsample_mode == 'pixelshuffle':
+ upsample_block = pixelshuffle_block
+ else:
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+ if upscale == 3:
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+ else:
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+ outact = act(finalact) if finalact else None
+
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+ *upsampler, HR_conv0, HR_conv1, outact)
+
+ def forward(self, x, outm=None):
+ if self.resrgan_scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ elif self.resrgan_scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ else:
+ feat = x
+
+ return self.model(feat)
class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
- def __init__(self, nf, gc=32):
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+ # This is for backwards compatibility with existing models
+ if nr == 3:
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ else:
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+ self.RDBs = nn.Sequential(*RDB_list)
def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
+ if hasattr(self, 'RDB1'):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ else:
+ out = self.RDBs(x)
return out * 0.2 + x
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, gc=32):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ {Rakotonirina} and A. {Rasoanaivo}
+ """
+
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(ResidualDenseBlock_5C, self).__init__()
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+ self.noise = GaussianNoise() if gaussian_noise else None
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ if mode == 'CNA':
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ x2 = x2 + self.conv1x1(x)
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ if self.noise:
+ return self.noise(x5.mul(0.2) + x)
+ else:
+ return x5 * 0.2 + x
+
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
+####################
+# ESRGANplus
+####################
+class GaussianNoise(nn.Module):
+ def __init__(self, sigma=0.1, is_relative_detach=False):
+ super().__init__()
+ self.sigma = sigma
+ self.is_relative_detach = is_relative_detach
+ self.noise = torch.tensor(0, dtype=torch.float)
+
+ def forward(self, x):
+ if self.training and self.sigma != 0:
+ self.noise = self.noise.to(x.device)
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+ x = x + sampled_noise
+ return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ """
+
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ super(Upsample, self).__init__()
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.size = size
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = 'scale_factor=' + str(self.scale_factor)
+ else:
+ info = 'size=' + str(self.size)
+ info += ', mode=' + self.mode
+ return info
+
+
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+ """ Upconv layer """
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+ return sequential(upsample, conv)
+
+
+
+
+
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+ Args:
+ basic_block (nn.module): nn.module class for basic block. (block)
+ num_basic_block (int): number of blocks. (n_layers)
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+ """ activation helper """
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ layer = nn.ReLU(inplace)
+ elif act_type in ('leakyrelu', 'lrelu'):
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == 'prelu':
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ elif act_type == 'tanh': # [-1, 1] range output
+ layer = nn.Tanh()
+ elif act_type == 'sigmoid': # [0, 1] range output
+ layer = nn.Sigmoid()
+ else:
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+ return layer
+
+
+class Identity(nn.Module):
+ def __init__(self, *kwargs):
+ super(Identity, self).__init__()
+
+ def forward(self, x, *kwargs):
+ return x
+
+
+def norm(norm_type, nc):
+ """ Return a normalization layer """
+ norm_type = norm_type.lower()
+ if norm_type == 'batch':
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == 'instance':
+ layer = nn.InstanceNorm2d(nc, affine=False)
+ elif norm_type == 'none':
+ def norm_layer(x): return Identity()
+ else:
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+ return layer
+
+
+def pad(pad_type, padding):
+ """ padding layer helper """
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == 'reflect':
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == 'replicate':
+ layer = nn.ReplicationPad2d(padding)
+ elif pad_type == 'zero':
+ layer = nn.ZeroPad2d(padding)
+ else:
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ShortcutBlock(nn.Module):
+ """ Elementwise sum the output of a submodule to its input """
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+ """ Flatten Sequential. It unwraps nn.Sequential. """
+ if len(args) == 1:
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError('sequential does not support OrderedDict input.')
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False):
+ """ Conv layer with padding, normalization, activation """
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+ padding = padding if pad_type == 'zero' else 0
+
+ if convtype=='PartialConv2D':
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='DeformConv2D':
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='Conv3D':
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ else:
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+
+ if spectral_norm:
+ c = nn.utils.spectral_norm(c)
+
+ a = act(act_type) if act_type else None
+ if 'CNA' in mode:
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == 'NAC':
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 00000000..b522125c
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,99 @@
+import os
+import sys
+import traceback
+
+import git
+
+from modules import paths, shared
+
+extensions = []
+extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
+
+
+def active():
+ return [x for x in extensions if x.enabled]
+
+
+class Extension:
+ def __init__(self, name, path, enabled=True, is_builtin=False):
+ self.name = name
+ self.path = path
+ self.enabled = enabled
+ self.status = ''
+ self.can_update = False
+ self.is_builtin = is_builtin
+
+ repo = None
+ try:
+ if os.path.exists(os.path.join(path, ".git")):
+ repo = git.Repo(path)
+ except Exception:
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if repo is None or repo.bare:
+ self.remote = None
+ else:
+ try:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+ except Exception:
+ self.remote = None
+
+ def list_files(self, subdir, extension):
+ from modules import scripts
+
+ dirpath = os.path.join(self.path, subdir)
+ if not os.path.isdir(dirpath):
+ return []
+
+ res = []
+ for filename in sorted(os.listdir(dirpath)):
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
+
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return res
+
+ def check_updates(self):
+ repo = git.Repo(self.path)
+ for fetch in repo.remote().fetch("--dry-run"):
+ if fetch.flags != fetch.HEAD_UPTODATE:
+ self.can_update = True
+ self.status = "behind"
+ return
+
+ self.can_update = False
+ self.status = "latest"
+
+ def fetch_and_reset_hard(self):
+ repo = git.Repo(self.path)
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+ repo.git.fetch('--all')
+ repo.git.reset('--hard', 'origin')
+
+
+def list_extensions():
+ extensions.clear()
+
+ if not os.path.isdir(extensions_dir):
+ return
+
+ paths = []
+ for dirname in [extensions_dir, extensions_builtin_dir]:
+ if not os.path.isdir(dirname):
+ return
+
+ for extension_dirname in sorted(os.listdir(dirname)):
+ path = os.path.join(dirname, extension_dirname)
+ if not os.path.isdir(path):
+ continue
+
+ paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+
+ for dirname, path, is_builtin in paths:
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
+ extensions.append(extension)
+
diff --git a/modules/extras.py b/modules/extras.py
index b853fa5b..d665440a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
import math
import os
+import sys
+import traceback
import numpy as np
from PIL import Image
@@ -7,7 +10,11 @@ from PIL import Image
import torch
import tqdm
-from modules import processing, shared, images, devices, sd_models
+from typing import Callable, List, OrderedDict, Tuple
+from functools import partial
+from dataclasses import dataclass
+
+from modules import processing, shared, images, devices, sd_models, sd_samplers
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -15,19 +22,50 @@ import modules.codeformer_model
import piexif
import piexif.helper
import gradio as gr
+import safetensors.torch
+
+class LruCache(OrderedDict):
+ @dataclass(frozen=True)
+ class Key:
+ image_hash: int
+ info_hash: int
+ args_hash: int
+
+ @dataclass
+ class Value:
+ image: Image.Image
+ info: str
+
+ def __init__(self, max_size: int = 5, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._max_size = max_size
+ def get(self, key: LruCache.Key) -> LruCache.Value:
+ ret = super().get(key)
+ if ret is not None:
+ self.move_to_end(key) # Move to end of eviction list
+ return ret
-cached_images = {}
+ def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
+ self[key] = value
+ while len(self) > self._max_size:
+ self.popitem(last=False)
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+cached_images: LruCache = LruCache(max_size=5)
+
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
outputs = []
-
+
if extras_mode == 1:
#convert file to pillow image
for img in image_folder:
@@ -39,9 +77,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
- image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ image_list = shared.listfiles(input_dir)
for img in image_list:
- image = Image.open(img)
+ try:
+ image = Image.open(img)
+ except Exception:
+ continue
imageArr.append(image)
imageNameArr.append(img)
else:
@@ -53,80 +94,128 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- for image, image_name in zip(imageArr, imageNameArr):
- if image is None:
- return outputs, "Please select an input image.", ''
- existing_pnginfo = image.info or {}
+ # Extra operation definitions
- image = image.convert("RGB")
- info = ""
+ def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
+ restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
- if gfpgan_visibility > 0:
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(image, res, gfpgan_visibility)
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
+ info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+ return (res, info)
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- image = res
+ def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
+ restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
- if codeformer_visibility > 0:
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
+ if codeformer_visibility < 1.0:
+ res = Image.blend(image, res, codeformer_visibility)
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
+ info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
+ return (res, info)
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
- image = res
+ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
+ upscaler = shared.sd_upscalers[scaler_index]
+ res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ if mode == 1 and crop:
+ cropped = Image.new("RGB", (resize_w, resize_h))
+ cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
+ res = cropped
+ return res
+ def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
+ nonlocal upscaling_resize
if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+ return (image, info)
+
+ @dataclass
+ class UpscaleParams:
+ upscaler_idx: int
+ blend_alpha: float
+
+ def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ blended_result: Image.Image = None
+ image_hash: str = hash(np.array(image.getdata()).tobytes())
+ for upscaler in params:
+ upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
+ upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+ cache_key = LruCache.Key(image_hash=image_hash,
+ info_hash=hash(info),
+ args_hash=hash(upscale_args))
+ cached_entry = cached_images.get(cache_key)
+ if cached_entry is None:
+ res = upscale(image, *upscale_args)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
+ cached_images.put(cache_key, LruCache.Value(image=res, info=info))
+ else:
+ res, info = cached_entry.image, cached_entry.info
+
+ if blended_result is None:
+ blended_result = res
+ else:
+ blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
+ return (blended_result, info)
+
+ # Build a list of operations to run
+ facefix_ops: List[Callable] = []
+ facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
+ facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
+
+ upscale_ops: List[Callable] = []
+ upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
+
+ if upscaling_resize != 0:
+ step_params: List[UpscaleParams] = []
+ step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
+
+ upscale_ops.append(partial(run_upscalers_blend, step_params))
+
+ extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
- resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
-
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
- c = cropped
- cached_images[key] = c
-
- return c
-
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
-
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
+ for image, image_name in zip(imageArr, imageNameArr):
+ if image is None:
+ return outputs, "Please select an input image.", ''
- image = res
+ shared.state.textinfo = f'Processing image {image_name}'
+
+ existing_pnginfo = image.info or {}
- while len(cached_images) > 2:
- del cached_images[next(iter(cached_images.keys()))]
+ image = image.convert("RGB")
+ info = ""
+ # Run each operation on each image
+ for op in extras_ops:
+ image, info = op(image, info)
- images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
- forced_filename=image_name if opts.use_original_name_batch else None)
+ if opts.use_original_name_batch and image_name is not None:
+ basename = os.path.splitext(os.path.basename(image_name))[0]
+ else:
+ basename = ''
- if opts.enable_pnginfo:
+ if opts.enable_pnginfo: # append info before save
image.info = existing_pnginfo
image.info["extras"] = info
+ if save_output:
+ # Add upscaler name as a suffix.
+ suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
+ # Add second upscaler if applicable.
+ if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
+ suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
+
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
+
if extras_mode != 2 or show_extras_results :
outputs.append(image)
@@ -134,30 +223,16 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, plaintext_to_html(info), ''
+def clear_cache():
+ cached_images.clear()
+
def run_pnginfo(image):
if image is None:
return '', '', ''
- items = image.info
- geninfo = ''
-
- if "exif" in image.info:
- exif = piexif.load(image.info["exif"])
- exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
- try:
- exif_comment = piexif.helper.UserComment.load(exif_comment)
- except ValueError:
- exif_comment = exif_comment.decode('utf8', errors="ignore")
-
- items['exif comment'] = exif_comment
- geninfo = exif_comment
-
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
-
- geninfo = items.get('parameters', geninfo)
+ geninfo, items = images.read_info_from_image(image)
+ items = {**{'parameters': geninfo}, **items}
info = ''
for key, text in items.items():
@@ -175,7 +250,10 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+ shared.state.begin()
+ shared.state.job = 'model-merge'
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -187,23 +265,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
-
- print(f"Loading {primary_model_info.filename}...")
- primary_model = torch.load(primary_model_info.filename, map_location='cpu')
- theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
-
- print(f"Loading {secondary_model_info.filename}...")
- secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
- theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
-
- if teritary_model_info is not None:
- print(f"Loading {teritary_model_info.filename}...")
- teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
- theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
- else:
- teritary_model = None
- theta_2 = None
+ tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
+ result_is_inpainting_model = False
theta_funcs = {
"Weighted sum": (None, weighted_sum),
@@ -211,9 +274,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
}
theta_func1, theta_func2 = theta_funcs[interp_method]
- print(f"Merging...")
+ if theta_func1 and not tertiary_model_info:
+ shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+ shared.state.end()
+ return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+
+ shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1:
+ print(f"Loading {tertiary_model_info.filename}...")
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
if key in theta_2:
@@ -221,12 +294,33 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
- del theta_2, teritary_model
+ del theta_2
+
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
+ print(f"Loading {primary_model_info.filename}...")
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
+
+ print("Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
+ a = theta_0[key]
+ b = theta_1[key]
+
+ shared.state.textinfo = f'Merging layer {key}'
+ # this enables merging an inpainting model (A) with another one (B);
+ # where normal model would have 4 channels, for latenst space, inpainting model would
+ # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
+ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
+ if a.shape[1] == 4 and b.shape[1] == 9:
+ raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
- theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
+
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
+ result_is_inpainting_model = True
+ else:
+ theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
theta_0[key] = theta_0[key].half()
@@ -237,17 +331,35 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
+ del theta_1
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
- filename = filename if custom_name == '' else (custom_name + '.ckpt')
+ filename = \
+ primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
+ secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
+ interp_method.replace(" ", "_") + \
+ '-merged.' + \
+ ("inpainting." if result_is_inpainting_model else "") + \
+ checkpoint_format
+
+ filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+
output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
- torch.save(primary_model, output_modelname)
+
+ _, extension = os.path.splitext(output_modelname)
+ if extension.lower() == ".safetensors":
+ safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ else:
+ torch.save(theta_0, output_modelname)
sd_models.list_models()
- print(f"Checkpoint saved.")
+ print("Checkpoint saved.")
+ shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ shared.state.end()
+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 98d24406..4baf4d9a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,14 +1,222 @@
+import base64
+import io
+import math
import os
import re
+from pathlib import Path
+
import gradio as gr
from modules.shared import script_path
-from modules import shared
+from modules import shared, ui_tempdir
+import tempfile
+from PIL import Image
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
+paste_fields = {}
+bind_list = []
+
+
+def reset():
+ paste_fields.clear()
+ bind_list.clear()
+
+
+def quote(text):
+ if ',' not in str(text):
+ return text
+
+ text = str(text)
+ text = text.replace('\\', '\\\\')
+ text = text.replace('"', '\\"')
+ return f'"{text}"'
+
+
+def image_from_url_text(filedata):
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+
+ if type(filedata) == dict and filedata.get("is_file", False):
+ filename = filedata["name"]
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
+ assert is_in_right_dir, 'trying to open image file outside of allowed directories'
+
+ return Image.open(filename)
+
+ if type(filedata) == list:
+ if len(filedata) == 0:
+ return None
+
+ filedata = filedata[0]
+
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
+ image = Image.open(io.BytesIO(filedata))
+ return image
+
+
+def add_paste_fields(tabname, init_img, fields):
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields}
+
+ # backwards compatibility for existing extensions
+ import modules.ui
+ if tabname == 'txt2img':
+ modules.ui.txt2img_paste_fields = fields
+ elif tabname == 'img2img':
+ modules.ui.img2img_paste_fields = fields
+
+
+def integrate_settings_paste_fields(component_dict):
+ from modules import ui
+
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'sd_hypernetwork_strength': 'Hypernet strength',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'inpainting_mask_weight': 'Conditional mask weight',
+ 'sd_model_checkpoint': 'Model hash',
+ 'eta_noise_seed_delta': 'ENSD',
+ 'initial_noise_multiplier': 'Noise multiplier',
+ }
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ for tabname, info in paste_fields.items():
+ if info["fields"] is not None:
+ info["fields"] += settings_paste_fields
+
+
+def create_buttons(tabs_list):
+ buttons = {}
+ for tab in tabs_list:
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
+ return buttons
+
+
+#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
+def bind_buttons(buttons, send_image, send_generate_info):
+ bind_list.append([buttons, send_image, send_generate_info])
+
+
+def send_image_and_dimensions(x):
+ if isinstance(x, Image.Image):
+ img = x
+ else:
+ img = image_from_url_text(x)
+
+ if shared.opts.send_size and isinstance(img, Image.Image):
+ w = img.width
+ h = img.height
+ else:
+ w = gr.update()
+ h = gr.update()
+
+ return img, w, h
+
+
+def run_bind():
+ for buttons, source_image_component, send_generate_info in bind_list:
+ for tab in buttons:
+ button = buttons[tab]
+ destination_image_component = paste_fields[tab]["init_img"]
+ fields = paste_fields[tab]["fields"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if source_image_component and destination_image_component:
+ if isinstance(source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
+ else:
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
+
+ if send_generate_info and fields is not None:
+ if send_generate_info in paste_fields:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+ button.click(
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
+ )
+ else:
+ connect_paste(button, fields, send_generate_info)
+
+ button.click(
+ fn=None,
+ _js=f"switch_to_{tab}",
+ inputs=None,
+ outputs=None,
+ )
+
+
+def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
+ """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
+
+ Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
+ parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
+
+ If the infotext has no hash, then a hypernet with the same name will be selected instead.
+ """
+ hypernet_name = hypernet_name.lower()
+ if hypernet_hash is not None:
+ # Try to match the hash in the name
+ for hypernet_key in shared.hypernetworks.keys():
+ result = re_hypernet_hash.search(hypernet_key)
+ if result is not None and result[1] == hypernet_hash:
+ return hypernet_key
+ else:
+ # Fall back to a hypernet with the same name
+ for hypernet_key in shared.hypernetworks.keys():
+ if hypernet_key.lower().startswith(hypernet_name):
+ return hypernet_key
+
+ return None
+
+
+def restore_old_hires_fix_params(res):
+ """for infotexts that specify old First pass size parameter, convert it into
+ width, height, and hr scale"""
+
+ firstpass_width = res.get('First pass size-1', None)
+ firstpass_height = res.get('First pass size-2', None)
+
+ if firstpass_width is None or firstpass_height is None:
+ return
+
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+ width = int(res.get("Size-1", 512))
+ height = int(res.get("Size-2", 512))
+
+ if firstpass_width == 0 or firstpass_height == 0:
+ # old algorithm for auto-calculating first pass size
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ firstpass_width = math.ceil(scale * width / 64) * 64
+ firstpass_height = math.ceil(scale * height / 64) * 64
+
+ hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
+
+ res['Size-1'] = firstpass_width
+ res['Size-2'] = firstpass_height
+ res['Hires upscale'] = hr_scale
def parse_generation_parameters(x: str):
@@ -45,8 +253,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
- res["Prompt"] = prompt
- res["Negative prompt"] = negative_prompt
+ res["Prompt"] = prompt
+ res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
m = re_imagesize.match(v)
@@ -56,10 +264,24 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
res[k] = v
+ # Missing CLIP skip means it was set to 1 (the default)
+ if "Clip skip" not in res:
+ res["Clip skip"] = "1"
+
+ if "Hypernet strength" not in res:
+ res["Hypernet strength"] = "1"
+
+ if "Hypernet" in res:
+ hypernet_name = res["Hypernet"]
+ hypernet_hash = res.get("Hypernet hash", None)
+ res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+
+ restore_old_hires_fix_params(res)
+
return res
-def connect_paste(button, paste_fields, input_comp, js=None):
+def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
@@ -83,7 +305,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else:
try:
valtype = type(output.value)
- val = valtype(v)
+
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+
res.append(gr.update(value=val))
except Exception:
res.append(gr.update())
@@ -92,7 +319,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
button.click(
fn=paste_func,
- _js=js,
+ _js=jsfunc,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
+
+
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index a9452dce..1e2dbc32 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -36,7 +36,9 @@ def gfpgann():
else:
print("Unable to load gfpgan model!")
return None
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+ if hasattr(facexlib.detection.retinaface, 'device'):
+ facexlib.detection.retinaface.device = devices.device_gfpgan
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model
return model
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74300122..450fecac 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,40 +1,72 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
-
-import torch
+import inspect
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
+import modules.textual_inversion.dataset
import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
+
+from collections import defaultdict, deque
+from statistics import stdev, mean
+
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
-
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+ activation_dict = {
+ "linear": torch.nn.Identity,
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ "tanh": torch.nn.Tanh,
+ "sigmoid": torch.nn.Sigmoid,
+ }
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__()
- assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+ # Add an activation func except last layer
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+ # Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+ # Add dropout except last layer
+ if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
+ linears.append(torch.nn.Dropout(p=0.3))
+
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
@@ -42,9 +74,25 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
-
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ w, b = layer.weight.data, layer.bias.data
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+ normal_(w, mean=0.0, std=0.01)
+ normal_(b, mean=0.0, std=0)
+ elif weight_init == 'XavierUniform':
+ xavier_uniform_(w)
+ zeros_(b)
+ elif weight_init == 'XavierNormal':
+ xavier_normal_(w)
+ zeros_(b)
+ elif weight_init == 'KaimingUniform':
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ elif weight_init == 'KaimingNormal':
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ else:
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
self.to(devices.device)
def fix_old_state_dict(self, state_dict):
@@ -69,7 +117,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- layer_structure += [layer.weight, layer.bias]
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -81,7 +130,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
@@ -89,26 +138,48 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
+ self.activation_func = activation_func
+ self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
+ self.activate_output = activate_output
+ self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+ self.optimizer_name = None
+ self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
+ self.eval_mode()
def weights(self):
res = []
+ for k, layers in self.layers.items():
+ for layer in layers:
+ res += layer.parameters()
+ return res
+ def train_mode(self):
for k, layers in self.layers.items():
for layer in layers:
layer.train()
- res += layer.trainables()
+ for param in layer.parameters():
+ param.requires_grad = True
- return res
+ def eval_mode(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
def save(self, filename):
state_dict = {}
+ optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -116,11 +187,23 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
+ state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['weight_initialization'] = self.weight_init
+ state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+ state_dict['activate_output'] = self.activate_output
+ state_dict['last_layer_dropout'] = self.last_layer_dropout
+
+ if self.optimizer_name is not None:
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
self.filename = filename
@@ -130,13 +213,38 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ print(self.layer_structure)
+ self.activation_func = state_dict.get('activation_func', None)
+ print(f"Activation function is {self.activation_func}")
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
+ print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ print(f"Layer norm is set to {self.add_layer_norm}")
+ self.use_dropout = state_dict.get('use_dropout', False)
+ print(f"Dropout usage is set to {self.use_dropout}" )
+ self.activate_output = state_dict.get('activate_output', True)
+ print(f"Activate last layer is set to {self.activate_output}")
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+ print(f"Optimizer name is {self.optimizer_name}")
+ if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
+ if self.optimizer_state_dict:
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -147,15 +255,18 @@ class Hypernetwork:
def list_hypernetworks(path):
res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
+ # Prevent a hypothetical "None.pt" from being listed.
+ if name != "None":
+ res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
- if path is not None:
+ # Prevent any file named "None.pt" from being loaded.
+ if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
@@ -166,7 +277,7 @@ def load_hypernetwork(filename):
print(traceback.format_exc(), file=sys.stderr)
else:
if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
+ print("Unloading hypernetwork")
shared.loaded_hypernetwork = None
@@ -240,16 +351,77 @@ def stack_conds(conds):
return torch.stack(conds)
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert hypernetwork_name, 'hypernetwork not selected'
+def statistics(data):
+ if len(data) < 2:
+ std = 0
+ else:
+ std = stdev(data)
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+ recent_data = data[-32:]
+ if len(recent_data) < 2:
+ std = 0
+ else:
+ std = stdev(recent_data)
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(list(loss_info[key]))
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ weight_init=weight_init,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ )
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return fn
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+ from modules import images
+
+ save_hypernetwork_every = save_hypernetwork_every or 0
+ create_image_every = create_image_every or 0
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -267,126 +439,229 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
images_dir = None
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
+ checkpoint = sd_models.select_checkpoint()
- losses = torch.zeros((32,))
+ initial_step = hypernetwork.step or 0
+ if initial_step >= steps:
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
+ return hypernetwork, filename
- last_saved_file = "<none>"
- last_saved_image = "<none>"
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
- return hypernetwork, filename
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
-
- pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, entries in pbar:
- hypernetwork.step = i + ititial_step
-
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = stack_conds([entry.cond for entry in entries]).to(devices.device)
- # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
- del c
-
- losses[hypernetwork.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- mean_loss = losses.mean()
- if torch.isnan(mean_loss):
- raise RuntimeError("Loss diverged.")
- pbar.set_description(f"loss: {mean_loss:.7f}")
-
- if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
- hypernetwork.save(last_saved_file)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{mean_loss:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
-
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
+ pin_memory = shared.opts.pin_memory
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+
+ latent_sampling_method = ds.latent_sampling_method
- preview_text = p.prompt
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
+ if unload:
+ shared.parallel_processing_allowed = False
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ weights = hypernetwork.weights()
+ hypernetwork.train_mode()
- if image is not None:
- shared.state.current_image = image
- image.save(last_saved_image)
- last_saved_image += f", prompt: {preview_text}"
+ # Here we use optimizer from saved HN, or we can specify as UI option.
+ if hypernetwork.optimizer_name in optimizer_dict:
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ optimizer_name = hypernetwork.optimizer_name
+ else:
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+ optimizer_name = 'AdamW'
- shared.state.job_no = hypernetwork.step
+ if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
+ try:
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+ except RuntimeError as e:
+ print("Cannot resume from saved optimizer!")
+ print(e)
+
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+ # n steps = batch_size * gradient_step * n image processed
+ steps_per_epoch = len(ds) // batch_size // gradient_step
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+ _loss_step = 0 #internal
+ # size = len(ds.indexes)
+ # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ # losses = torch.zeros((size,))
+ # previous_mean_losses = [0]
+ # previous_mean_loss = 0
+ # print("Mean loss of {} elements".format(size))
+
+ steps_without_grad = 0
- shared.state.textinfo = f"""
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+ forced_filename = "<none>"
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ for i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if tag_drop_out != 0 or shuffle_tags:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ else:
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+ loss = shared.sd_model(x, c)[0] / gradient_step
+ del x
+ del c
+
+ _loss_step += loss.item()
+ scaler.scale(loss).backward()
+ # go back until we reach gradient accumulation steps
+ if (j + 1) % gradient_step != 0:
+ continue
+ # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
+ # scaler.unscale_(optimizer)
+ # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+ # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
+ # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+ scaler.step(optimizer)
+ scaler.update()
+ hypernetwork.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+
+ steps_done = hypernetwork.step + 1
+
+ epoch_num = hypernetwork.step // steps_per_epoch
+ epoch_step = hypernetwork.step % steps_per_epoch
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ hypernetwork.eval_mode()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+ hypernetwork.train_mode()
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
<p>
-Loss: {mean_loss:.7f}<br/>
-Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ pbar.leave = False
+ pbar.close()
+ hypernetwork.eval_mode()
+ #report_statistics(loss_dict)
- checkpoint = sd_models.select_checkpoint()
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
- hypernetwork.sd_checkpoint = checkpoint.hash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- hypernetwork.save(filename)
+ del optimizer
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename
-
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+ old_hypernetwork_name = hypernetwork.name
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+ try:
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.name = hypernetwork_name
+ hypernetwork.save(filename)
+ except:
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+ hypernetwork.name = old_hypernetwork_name
+ raise
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..e7f9e593 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,31 +3,16 @@ import os
import re
import gradio as gr
+import modules.hypernetworks.hypernetwork
+from modules import devices, sd_hijack, shared
-import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
-from modules.hypernetworks import hypernetwork
+not_available = ["hardswish", "multiheadattention"]
+keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
-
- if type(layer_structure) == str:
- layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
-
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
- name=name,
- enable_sizes=[int(x) for x in enable_sizes],
- layer_structure=layer_structure,
- add_layer_norm=add_layer_norm,
- )
- hypernet.save(fn)
-
- shared.reload_hypernetworks()
-
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py
index b9589563..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,8 @@
import datetime
+import sys
+import traceback
+
+import pytz
import io
import math
import os
@@ -11,8 +15,9 @@ import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
+import json
-from modules import sd_samplers, shared
+from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -34,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows)
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+ script_callbacks.image_grid_callback(params)
+
w, h = imgs[0].size
- grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
- for i, img in enumerate(imgs):
- grid.paste(img, box=(i % cols * w, i // cols * h))
+ for i, img in enumerate(params.imgs):
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid
@@ -131,8 +139,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
lines.append(word)
return lines
- def draw_texts(drawing, draw_x, draw_y, lines):
+ def get_font(fontsize):
+ try:
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
+ except Exception:
+ return ImageFont.truetype(Roboto, fontsize)
+
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
for i, line in enumerate(lines):
+ fnt = initial_fnt
+ fontsize = initial_fontsize
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
+ fontsize -= 1
+ fnt = get_font(fontsize)
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
@@ -143,10 +162,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
fontsize = (width + height) // 25
line_spacing = fontsize // 2
- try:
- fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
- except Exception:
- fnt = ImageFont.truetype(Roboto, fontsize)
+ fnt = get_font(fontsize)
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
@@ -173,6 +189,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
for line in texts:
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+ line.allowed_width = allowed_width
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
@@ -189,13 +206,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_heights[col] / 2
- draw_texts(d, x, y, hor_texts[col])
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
for row in range(rows):
x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
- draw_texts(d, x, y, ver_texts[row])
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
return result
@@ -213,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+ """
+ Resizes an image with the specified resize_mode, width, and height.
+
+ Args:
+ resize_mode: The mode to use when resizing the image.
+ 0: Resize the image to the specified width and height.
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+ im: The image to resize.
+ width: The width to resize the image to.
+ height: The height to resize the image to.
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+ """
+
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
def resize(im, w, h):
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height)
if scale > 1.0:
- upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
- assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
@@ -273,10 +306,15 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
+re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
def sanitize_filename_part(text, replace_spaces=True):
+ if text is None:
+ return None
+
if replace_spaces:
text = text.replace(' ', '_')
@@ -286,49 +324,105 @@ def sanitize_filename_part(text, replace_spaces=True):
return text
-def apply_filename_pattern(x, p, seed, prompt):
- max_prompt_words = opts.directories_max_prompt_words
-
- if seed is not None:
- x = x.replace("[seed]", str(seed))
+class FilenameGenerator:
+ replacements = {
+ 'seed': lambda self: self.seed if self.seed is not None else '',
+ 'steps': lambda self: self.p and self.p.steps,
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
+ 'width': lambda self: self.image.width,
+ 'height': lambda self: self.image.height,
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
+ 'prompt_words': lambda self: self.prompt_words(),
+ }
+ default_time_format = '%Y%m%d%H%M%S'
+
+ def __init__(self, p, seed, prompt, image):
+ self.p = p
+ self.seed = seed
+ self.prompt = prompt
+ self.image = image
+
+ def prompt_no_style(self):
+ if self.p is None or self.prompt is None:
+ return None
+
+ prompt_no_style = self.prompt
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
+ if len(style) > 0:
+ for part in style.split("{prompt}"):
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
+
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
+
+ def prompt_words(self):
+ words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+ if len(words) == 0:
+ words = ["empty"]
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
+
+ def datetime(self, *args):
+ time_datetime = datetime.datetime.now()
+
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+ try:
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ except pytz.exceptions.UnknownTimeZoneError as _:
+ time_zone = None
+
+ time_zone_time = time_datetime.astimezone(time_zone)
+ try:
+ formatted_time = time_zone_time.strftime(time_format)
+ except (ValueError, TypeError) as _:
+ formatted_time = time_zone_time.strftime(self.default_time_format)
+
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
+
+ def apply(self, x):
+ res = ''
+
+ for m in re_pattern.finditer(x):
+ text, pattern = m.groups()
+ res += text
+
+ if pattern is None:
+ continue
- if p is not None:
- x = x.replace("[steps]", str(p.steps))
- x = x.replace("[cfg]", str(p.cfg_scale))
- x = x.replace("[width]", str(p.width))
- x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+ pattern_args = []
+ while True:
+ m = re_pattern_arg.match(pattern)
+ if m is None:
+ break
- x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
- x = x.replace("[date]", datetime.date.today().isoformat())
- x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
- x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
+ pattern, arg = m.groups()
+ pattern_args.insert(0, arg)
- # Apply [prompt] at last. Because it may contain any replacement word.^M
- if prompt is not None:
- x = x.replace("[prompt]", sanitize_filename_part(prompt))
- if "[prompt_no_styles]" in x:
- prompt_no_style = prompt
- for style in shared.prompt_styles.get_style_prompts(p.styles):
- if len(style) > 0:
- style_parts = [y for y in style.split("{prompt}")]
- for part in style_parts:
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
- prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
- x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
+ fun = self.replacements.get(pattern.lower())
+ if fun is not None:
+ try:
+ replacement = fun(self, *pattern_args)
+ except Exception:
+ replacement = None
+ print(f"Error adding [{pattern}] to filename", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
- x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
- if "[prompt_words]" in x:
- words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
- if len(words) == 0:
- words = ["empty"]
- x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
+ if replacement is not None:
+ res += str(replacement)
+ continue
- if cmd_opts.hide_ui_dir_config:
- x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
+ res += f'[{pattern}]'
- return x
+ return res
def get_next_sequence_number(path, basename):
@@ -354,7 +448,7 @@ def get_next_sequence_number(path, basename):
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
- '''Save an image.
+ """Save an image.
Args:
image (`PIL.Image`):
@@ -363,7 +457,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
basename (`str`):
The base filename which will be applied to `filename pattern`.
- seed, prompt, short_filename,
+ seed, prompt, short_filename,
extension (`str`):
Image file extension, default is `png`.
pngsectionname (`str`):
@@ -385,66 +479,94 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
The full path of the saved imaged.
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
- '''
- if short_filename or prompt is None or seed is None:
- file_decoration = ""
- elif opts.save_to_dirs:
- file_decoration = opts.samples_filename_pattern or "[seed]"
- else:
- file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
-
- if file_decoration != "":
- file_decoration = "-" + file_decoration.lower()
-
- file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
-
- if extension == 'png' and opts.enable_pnginfo and info is not None:
- pnginfo = PngImagePlugin.PngInfo()
-
- if existing_info is not None:
- for k, v in existing_info.items():
- pnginfo.add_text(k, str(v))
-
- pnginfo.add_text(pnginfo_section_name, info)
- else:
- pnginfo = None
+ """
+ namegen = FilenameGenerator(p, seed, prompt, image)
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
if forced_filename is None:
- basecount = get_next_sequence_number(path, basename)
- fullfn = "a.png"
- fullfn_without_extension = "a"
- for i in range(500):
- fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
- if not os.path.exists(fullfn):
- break
+ if short_filename or seed is None:
+ file_decoration = ""
+ elif opts.save_to_dirs:
+ file_decoration = opts.samples_filename_pattern or "[seed]"
+ else:
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+
+ add_number = opts.save_images_add_number or file_decoration == ''
+
+ if file_decoration != "" and add_number:
+ file_decoration = "-" + file_decoration
+
+ file_decoration = namegen.apply(file_decoration) + suffix
+
+ if add_number:
+ basecount = get_next_sequence_number(path, basename)
+ fullfn = None
+ for i in range(500):
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
+ if not os.path.exists(fullfn):
+ break
+ else:
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
else:
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
- fullfn_without_extension = os.path.join(path, forced_filename)
-
- def exif_bytes():
- return piexif.dump({
- "Exif": {
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
- },
- })
-
- if extension.lower() in ("jpg", "jpeg", "webp"):
- image.save(fullfn, quality=opts.jpeg_quality)
- if opts.enable_pnginfo and info is not None:
- piexif.insert(exif_bytes(), fullfn)
- else:
- image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
+
+ pnginfo = existing_info or {}
+ if info is not None:
+ pnginfo[pnginfo_section_name] = info
+
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
+ script_callbacks.before_image_saved_callback(params)
+
+ image = params.image
+ fullfn = params.filename
+ info = params.pnginfo.get(pnginfo_section_name, None)
+
+ def _atomically_save_image(image_to_save, filename_without_extension, extension):
+ # save image with .tmp extension to avoid race condition when another process detects new image in the directory
+ temp_file_path = filename_without_extension + ".tmp"
+ image_format = Image.registered_extensions()[extension]
+
+ if extension.lower() == '.png':
+ pnginfo_data = PngImagePlugin.PngInfo()
+ if opts.enable_pnginfo:
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
+
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ if image_to_save.mode == 'RGBA':
+ image_to_save = image_to_save.convert("RGB")
+
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+
+ if opts.enable_pnginfo and info is not None:
+ exif_bytes = piexif.dump({
+ "Exif": {
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
+ },
+ })
+
+ piexif.insert(exif_bytes, temp_file_path)
+ else:
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+
+ # atomically rename the file with correct extension
+ os.replace(temp_file_path, filename_without_extension + extension)
+
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
+ _atomically_save_image(image, fullfn_without_extension, extension)
+
+ image.already_saved_as = fullfn
target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length
@@ -456,9 +578,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif oversize:
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
- image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
- if opts.enable_pnginfo and info is not None:
- piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt"
@@ -467,13 +587,50 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
+ script_callbacks.image_saved_callback(params)
+
return fullfn, txt_fullfn
+def read_info_from_image(image):
+ items = image.info or {}
+
+ geninfo = items.pop('parameters', None)
+
+ if "exif" in items:
+ exif = piexif.load(items["exif"])
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
+ try:
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
+ except ValueError:
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
+
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
+
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration']:
+ items.pop(field, None)
+
+ if items.get("Software", None) == "NovelAI":
+ try:
+ json_info = json.loads(items["Comment"])
+ sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
+
+ geninfo = f"""{items["Description"]}
+Negative prompt: {json_info["uc"]}
+Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
+ except Exception:
+ print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return geninfo, items
+
+
def image_data(data):
try:
image = Image.open(io.BytesIO(data))
- textinfo = image.text["parameters"]
+ textinfo, _ = read_info_from_image(image)
return textinfo, None
except Exception:
pass
@@ -487,3 +644,14 @@ def image_data(data):
pass
return '', None
+
+
+def flatten(img, bgcolor):
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
+
+ if img.mode == "RGBA":
+ background = Image.new('RGBA', img.size, bgcolor)
+ background.paste(img, mask=img)
+ img = background
+
+ return img.convert('RGB')
diff --git a/modules/images_history.py b/modules/images_history.py
deleted file mode 100644
index 46b23e56..00000000
--- a/modules/images_history.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import os
-import shutil
-import sys
-
-def traverse_all_files(output_dir, image_list, curr_dir=None):
- curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
- try:
- f_list = os.listdir(curr_path)
- except:
- if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
- image_list.append(curr_dir)
- return image_list
- for file in f_list:
- file = file if curr_dir is None else os.path.join(curr_dir, file)
- file_path = os.path.join(curr_path, file)
- if file[-4:] == ".txt":
- pass
- elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
- image_list.append(file)
- else:
- image_list = traverse_all_files(output_dir, image_list, file)
- return image_list
-
-
-def get_recent_images(dir_name, page_index, step, image_index, tabname):
- page_index = int(page_index)
- image_list = []
- if not os.path.exists(dir_name):
- pass
- elif os.path.isdir(dir_name):
- image_list = traverse_all_files(dir_name, image_list)
- image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
- else:
- print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
- num = 48 if tabname != "extras" else 12
- max_page_index = len(image_list) // num + 1
- page_index = max_page_index if page_index == -1 else page_index + step
- page_index = 1 if page_index < 1 else page_index
- page_index = max_page_index if page_index > max_page_index else page_index
- idx_frm = (page_index - 1) * num
- image_list = image_list[idx_frm:idx_frm + num]
- image_index = int(image_index)
- if image_index < 0 or image_index > len(image_list) - 1:
- current_file = None
- hidden = None
- else:
- current_file = image_list[int(image_index)]
- hidden = os.path.join(dir_name, current_file)
- return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
-
-
-def first_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, 1, 0, image_index, tabname)
-
-
-def end_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, -1, 0, image_index, tabname)
-
-
-def prev_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, -1, image_index, tabname)
-
-
-def next_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 1, image_index, tabname)
-
-
-def page_index_change(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 0, image_index, tabname)
-
-
-def show_image_info(num, image_path, filenames):
- # print(f"select image {num}")
- file = filenames[int(num)]
- return file, num, os.path.join(image_path, file)
-
-
-def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
- if name == "":
- return filenames, delete_num
- else:
- delete_num = int(delete_num)
- index = list(filenames).index(name)
- i = 0
- new_file_list = []
- for name in filenames:
- if i >= index and i < index + delete_num:
- path = os.path.join(dir_name, name)
- if os.path.exists(path):
- print(f"Delete file {path}")
- os.remove(path)
- txt_file = os.path.splitext(path)[0] + ".txt"
- if os.path.exists(txt_file):
- os.remove(txt_file)
- else:
- print(f"Not exists file {path}")
- else:
- new_file_list.append(name)
- i += 1
- return new_file_list, 1
-
-
-def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
- if opts.outdir_samples != "":
- dir_name = opts.outdir_samples
- elif tabname == "txt2img":
- dir_name = opts.outdir_txt2img_samples
- elif tabname == "img2img":
- dir_name = opts.outdir_img2img_samples
- elif tabname == "extras":
- dir_name = opts.outdir_extras_samples
- else:
- return
- with gr.Row():
- renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
- with gr.Row(elem_id=tabname + "_images_history"):
- with gr.Row():
- with gr.Column(scale=2):
- history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
- with gr.Row():
- delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
- delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
- with gr.Column():
- with gr.Row():
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
- with gr.Row():
- with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False)
- img_file_name = gr.Textbox(label="File Name", interactive=False)
- with gr.Row():
- # hiden items
-
- img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
- tabname_box = gr.Textbox(tabname, visible=False)
- image_index = gr.Textbox(value=-1, visible=False)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
- filenames = gr.State()
- hidden = gr.Image(type="pil", visible=False)
- info1 = gr.Textbox(visible=False)
- info2 = gr.Textbox(visible=False)
-
- # turn pages
- gallery_inputs = [img_path, page_index, image_index, tabname_box]
- gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
-
- first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
-
- # other funcitons
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
- img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
- hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
-
- # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
- switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
- switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
-
-
-def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
- with gr.Blocks(analytics_enabled=False) as images_history:
- with gr.Tabs() as tabs:
- with gr.Tab("txt2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
- show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
- with gr.Tab("img2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
- with gr.Tab("extras history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
- return images_history
diff --git a/modules/img2img.py b/modules/img2img.py
index 24126774..ca58b5d8 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -4,9 +4,9 @@ import sys
import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
-from modules import devices
+from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -19,7 +19,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, args):
processing.fix_seed(p)
- images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ images = shared.listfiles(input_dir)
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
@@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
break
img = Image.open(image)
+ # Use the EXIF orientation of photos taken by smartphones.
+ img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -53,27 +55,48 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}"
if not save_normally:
+ os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
if is_inpaint:
+ # Drawn mask
if mask_mode == 0:
- image = init_img_with_mask['image']
- mask = init_img_with_mask['mask']
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- image = image.convert('RGB')
+ is_mask_sketch = isinstance(init_img_with_mask, dict)
+ is_mask_paint = not is_mask_sketch
+ if is_mask_sketch:
+ # Sketch: mask iff. not transparent
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ else:
+ # Color-sketch: mask iff. painted over
+ image = init_img_with_mask
+ orig = init_img_with_mask_orig or init_img_with_mask
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+ blur = ImageFilter.GaussianBlur(mask_blur)
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+
+ image = image.convert("RGB")
+ # Uploaded mask
else:
image = init_img_inpaint
mask = init_mask_inpaint
+ # No mask
else:
image = init_img
mask = None
+ # Use the EXIF orientation of photos taken by smartphones.
+ if image is not None:
+ image = ImageOps.exif_transpose(image)
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
@@ -89,7 +112,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_index=sampler_index,
+ sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
@@ -109,6 +132,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
@@ -125,6 +151,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if processed is None:
processed = process_images(p)
+ p.close()
+
shared.total_tqdm.clear()
generation_info_js = processed.js()
@@ -134,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/import_hook.py b/modules/import_hook.py
new file mode 100644
index 00000000..28c67dfa
--- /dev/null
+++ b/modules/import_hook.py
@@ -0,0 +1,5 @@
+import sys
+
+# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
+if "--xformers" not in "".join(sys.argv):
+ sys.modules["xformers"] = None
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 64b91eb4..738d8ff7 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,4 +1,3 @@
-import contextlib
import os
import sys
import traceback
@@ -11,10 +10,9 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
blip_image_eval_size = 384
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
@@ -28,9 +26,11 @@ class InterrogateModels:
clip_preprocess = None
categories = None
dtype = None
+ running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
@@ -45,7 +45,14 @@ class InterrogateModels:
def load_blip_model(self):
import models.blip
- blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+ files = modelloader.load_models(
+ model_path=os.path.join(paths.models_path, "BLIP"),
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+ ext_filter=[".pth"],
+ download_name='model_base_caption_capfilt_large.pth',
+ )
+
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()
return blip_model
@@ -53,7 +60,11 @@ class InterrogateModels:
def load_clip_model(self):
import clip
- model, preprocess = clip.load(clip_model_name)
+ if self.running_on_cpu:
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
+ else:
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
+
model.eval()
model = model.to(devices.device_interrogate)
@@ -62,14 +73,14 @@ class InterrogateModels:
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate)
@@ -124,8 +135,9 @@ class InterrogateModels:
return caption[0]
def interrogate(self, pil_image):
- res = None
-
+ res = ""
+ shared.state.begin()
+ shared.state.job = 'interrogate'
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -142,8 +154,7 @@ class InterrogateModels:
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
- precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
- with torch.no_grad(), precision_scope("cuda"):
+ with torch.no_grad(), devices.autocast():
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True)
@@ -162,10 +173,11 @@ class InterrogateModels:
res += ", " + match
except Exception:
- print(f"Error interrogating", file=sys.stderr)
+ print("Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
res += "<error>"
self.unload()
+ shared.state.end()
return res
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py
deleted file mode 100644
index 8c4db44a..00000000
--- a/modules/ldsr_model.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import os
-import sys
-import traceback
-
-from basicsr.utils.download_util import load_file_from_url
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.ldsr_model_arch import LDSR
-from modules import shared
-
-
-class UpscalerLDSR(Upscaler):
- def __init__(self, user_path):
- self.name = "LDSR"
- self.user_path = user_path
- self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
- self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
- super().__init__()
- scaler_data = UpscalerData("LDSR", None, self)
- self.scalers = [scaler_data]
-
- def load_model(self, path: str):
- # Remove incorrect project.yaml file if too big
- yaml_path = os.path.join(self.model_path, "project.yaml")
- old_model_path = os.path.join(self.model_path, "model.pth")
- new_model_path = os.path.join(self.model_path, "model.ckpt")
- if os.path.exists(yaml_path):
- statinfo = os.stat(yaml_path)
- if statinfo.st_size >= 10485760:
- print("Removing invalid LDSR YAML file.")
- os.remove(yaml_path)
- if os.path.exists(old_model_path):
- print("Renaming model from model.pth to model.ckpt")
- os.rename(old_model_path, new_model_path)
- model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
- file_name="model.ckpt", progress=True)
- yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
- file_name="project.yaml", progress=True)
-
- try:
- return LDSR(model, yaml)
-
- except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- return None
-
- def do_upscale(self, img, path):
- ldsr = self.load_model(path)
- if ldsr is None:
- print("NO LDSR!")
- return img
- ddim_steps = shared.opts.ldsr_steps
- return ldsr.super_resolution(img, ddim_steps, self.scale)
diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py
deleted file mode 100644
index 14db5076..00000000
--- a/modules/ldsr_model_arch.py
+++ /dev/null
@@ -1,222 +0,0 @@
-import gc
-import time
-import warnings
-
-import numpy as np
-import torch
-import torchvision
-from PIL import Image
-from einops import rearrange, repeat
-from omegaconf import OmegaConf
-
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.util import instantiate_from_config, ismap
-
-warnings.filterwarnings("ignore", category=UserWarning)
-
-
-# Create LDSR Class
-class LDSR:
- def load_model_from_config(self, half_attention):
- print(f"Loading model from {self.modelPath}")
- pl_sd = torch.load(self.modelPath, map_location="cpu")
- sd = pl_sd["state_dict"]
- config = OmegaConf.load(self.yamlPath)
- model = instantiate_from_config(config.model)
- model.load_state_dict(sd, strict=False)
- model.cuda()
- if half_attention:
- model = model.half()
-
- model.eval()
- return {"model": model}
-
- def __init__(self, model_path, yaml_path):
- self.modelPath = model_path
- self.yamlPath = yaml_path
-
- @staticmethod
- def run(model, selected_path, custom_steps, eta):
- example = get_cond(selected_path)
-
- n_runs = 1
- guider = None
- ckwargs = None
- ddim_use_x0_pred = False
- temperature = 1.
- eta = eta
- custom_shape = None
-
- height, width = example["image"].shape[1:3]
- split_input = height >= 128 and width >= 128
-
- if split_input:
- ks = 128
- stride = 64
- vqf = 4 #
- model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
- "vqf": vqf,
- "patch_distributed_vq": True,
- "tie_braker": False,
- "clip_max_weight": 0.5,
- "clip_min_weight": 0.01,
- "clip_max_tie_weight": 0.5,
- "clip_min_tie_weight": 0.01}
- else:
- if hasattr(model, "split_input_params"):
- delattr(model, "split_input_params")
-
- x_t = None
- logs = None
- for n in range(n_runs):
- if custom_shape is not None:
- x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
- x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
-
- logs = make_convolutional_sample(example, model,
- custom_steps=custom_steps,
- eta=eta, quantize_x0=False,
- custom_shape=custom_shape,
- temperature=temperature, noise_dropout=0.,
- corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
- ddim_use_x0_pred=ddim_use_x0_pred
- )
- return logs
-
- def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
- model = self.load_model_from_config(half_attention)
-
- # Run settings
- diffusion_steps = int(steps)
- eta = 1.0
-
- down_sample_method = 'Lanczos'
-
- gc.collect()
- torch.cuda.empty_cache()
-
- im_og = image
- width_og, height_og = im_og.size
- # If we can adjust the max upscale size, then the 4 below should be our variable
- down_sample_rate = target_scale / 4
- wd = width_og * down_sample_rate
- hd = height_og * down_sample_rate
- width_downsampled_pre = int(wd)
- height_downsampled_pre = int(hd)
-
- if down_sample_rate != 1:
- print(
- f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
- im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
- else:
- print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
- logs = self.run(model["model"], im_og, diffusion_steps, eta)
-
- sample = logs["sample"]
- sample = sample.detach().cpu()
- sample = torch.clamp(sample, -1., 1.)
- sample = (sample + 1.) / 2. * 255
- sample = sample.numpy().astype(np.uint8)
- sample = np.transpose(sample, (0, 2, 3, 1))
- a = Image.fromarray(sample[0])
-
- del model
- gc.collect()
- torch.cuda.empty_cache()
- return a
-
-
-def get_cond(selected_path):
- example = dict()
- up_f = 4
- c = selected_path.convert('RGB')
- c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
- c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
- antialias=True)
- c_up = rearrange(c_up, '1 c h w -> 1 h w c')
- c = rearrange(c, '1 c h w -> 1 h w c')
- c = 2. * c - 1.
-
- c = c.to(torch.device("cuda"))
- example["LR_image"] = c
- example["image"] = c_up
-
- return example
-
-
-@torch.no_grad()
-def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
- mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
- corrector_kwargs=None, x_t=None
- ):
- ddim = DDIMSampler(model)
- bs = shape[0]
- shape = shape[1:]
- print(f"Sampling with eta = {eta}; steps: {steps}")
- samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
- normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
- mask=mask, x0=x0, temperature=temperature, verbose=False,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs, x_t=x_t)
-
- return samples, intermediates
-
-
-@torch.no_grad()
-def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
- corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
- log = dict()
-
- z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
- return_first_stage_outputs=True,
- force_c_encode=not (hasattr(model, 'split_input_params')
- and model.cond_stage_key == 'coordinates_bbox'),
- return_original_cond=True)
-
- if custom_shape is not None:
- z = torch.randn(custom_shape)
- print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
-
- z0 = None
-
- log["input"] = x
- log["reconstruction"] = xrec
-
- if ismap(xc):
- log["original_conditioning"] = model.to_rgb(xc)
- if hasattr(model, 'cond_stage_key'):
- log[model.cond_stage_key] = model.to_rgb(xc)
-
- else:
- log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
- if model.cond_stage_model:
- log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
- if model.cond_stage_key == 'class_label':
- log[model.cond_stage_key] = xc[model.cond_stage_key]
-
- with model.ema_scope("Plotting"):
- t0 = time.time()
-
- sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
- eta=eta,
- quantize_x0=quantize_x0, mask=None, x0=z0,
- temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
- x_t=x_T)
- t1 = time.time()
-
- if ddim_use_x0_pred:
- sample = intermediates['pred_x0'][-1]
-
- x_sample = model.decode_first_stage(sample)
-
- try:
- x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
- log["sample_noquant"] = x_sample_noquant
- log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
- except:
- pass
-
- log["sample"] = x_sample
- log["time"] = t1 - t0
-
- return log
diff --git a/modules/localization.py b/modules/localization.py
index b1810cda..f6a6f2fb 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -3,6 +3,7 @@ import os
import sys
import traceback
+
localizations = {}
@@ -16,6 +17,11 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
+ from modules import scripts
+ for file in scripts.list_scripts("localizations", ".json"):
+ fn, ext = os.path.splitext(file.filename)
+ localizations[fn] = file.path
+
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 7eba1349..042a0254 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,8 @@
import torch
-from modules.devices import get_optimal_device
+from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
def send_everything_to_cpu():
@@ -33,34 +32,49 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
- module.to(gpu)
+ module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
- # remove three big modules, cond, first_stage, and unet from the model and then
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
+
+ # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
+ if hasattr(sd_model.cond_stage_model, 'model'):
+ sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+
+ # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
- sd_model.to(device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
+ sd_model.to(devices.device)
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
- # register hooks for those the first two models
+ # register hooks for those the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
+ if sd_model.depth_model:
+ sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
+ if hasattr(sd_model.cond_stage_model, 'model'):
+ sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
+ del sd_model.cond_stage_model.transformer
+
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
@@ -70,7 +84,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(device)
+ sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff --git a/modules/masking.py b/modules/masking.py
index fd8d9241..a5c4d2da 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
- desired_height = (x2 - x1) * ratio_processing
+ desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2-y1))
y1 -= desired_height_diff//2
y2 += desired_height_diff - desired_height_diff//2
diff --git a/modules/memmon.py b/modules/memmon.py
index 9fb9b687..a7060f58 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
free, total = torch.cuda.mem_get_info()
+ self.data["free"] = free
self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device)
+ self.data["active"] = torch_stats["active.all.current"]
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+ self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"]
diff --git a/modules/modelloader.py b/modules/modelloader.py
index b0f2f33d..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -82,9 +82,13 @@ def cleanup_models():
src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion")
move_files(src_path, dest_path, ".ckpt")
+ move_files(src_path, dest_path, ".safetensors")
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
+ src_path = os.path.join(models_path, "BSRGAN")
+ dest_path = os.path.join(models_path, "ESRGAN")
+ move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
@@ -119,11 +123,27 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
- sd = shared.script_path
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
- modules_dir = os.path.join(sd, "modules")
+ modules_dir = os.path.join(shared.script_path, "modules")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")
@@ -132,22 +152,16 @@ def load_upscalers():
importlib.import_module(full_model)
except:
pass
+
datas = []
- c_o = vars(shared.cmd_opts)
+ commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
- module_name = cls.__module__
- module = importlib.import_module(module_name)
- class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
- opt_string = None
- try:
- if cmd_name in c_o:
- opt_string = c_o[cmd_name]
- except:
- pass
- scaler = class_(opt_string)
- for child in scaler.scalers:
- datas.append(child)
+ scaler = cls(commandline_options.get(cmd_name, None))
+ datas += scaler.scalers
shared.sd_upscalers = datas
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 5c5f349a..3df2c06b 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -1,14 +1,23 @@
from pyngrok import ngrok, conf, exception
-
def connect(token, port, region):
- if token == None:
+ account = None
+ if token is None:
token = 'None'
+ else:
+ if ':' in token:
+ # token = authtoken:username:password
+ account = token.split(':')[1] + ':' + token.split(':')[-1]
+ token = token.split(':')[0]
+
config = conf.PyngrokConfig(
auth_token=token, region=region
)
try:
- public_url = ngrok.connect(port, pyngrok_config=config).public_url
+ if account is None:
+ public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
+ else:
+ public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
diff --git a/modules/paths.py b/modules/paths.py
index 1e7a2fbc..4dd03a35 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places
sd_path = None
-possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
+possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)
diff --git a/modules/processing.py b/modules/processing.py
index bcb0c32c..c7264aff 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -2,6 +2,7 @@ import json
import math
import os
import sys
+import warnings
import torch
import numpy as np
@@ -12,15 +13,21 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
+import modules.sd_models as sd_models
+import modules.sd_vae as sd_vae
import logging
+from ldm.data.util import AddMiDaS
+from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+from einops import repeat, rearrange
+from blendmodes.blend import blendLayers, BlendType
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -33,34 +40,68 @@ def setup_color_correction(image):
return correction_target
-def apply_color_correction(correction, image):
+def apply_color_correction(correction, original_image):
logging.info("Applying color correction.")
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
- np.asarray(image),
+ np.asarray(original_image),
cv2.COLOR_RGB2LAB
),
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
+
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+
+ return image
+
+
+def apply_overlay(image, paste_loc, index, overlays):
+ if overlays is None or index >= len(overlays):
+ return image
+
+ overlay = overlays[index]
+
+ if paste_loc is not None:
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
return image
-def get_correct_sampler(p):
- if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
- return sd_samplers.samplers
- elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
- return sd_samplers.samplers_for_img2img
- elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
- return sd_samplers.samplers
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
-
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
+ if sampler_index is not None:
+ print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
+
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -73,7 +114,7 @@ class StableDiffusionProcessing():
self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
- self.sampler_index: int = sampler_index
+ self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
@@ -90,13 +131,16 @@ class StableDiffusionProcessing():
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
- self.denoising_strength: float = 0
+ self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = opts.ddim_discretize
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
+ self.is_using_inpainting_conditioning = False
if not seed_enable_extras:
self.subseed = -1
@@ -104,16 +148,100 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+ self.scripts = None
+ self.script_args = None
+ self.all_prompts = None
+ self.all_negative_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+ self.iteration = 0
+
+ def txt2img_image_conditioning(self, x, width=None, height=None):
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
+
+ def depth2img_image_conditioning(self, source_image):
+ # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
+ transformer = AddMiDaS(model_type="dpt_hybrid")
+ transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+ midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning = torch.nn.functional.interpolate(
+ self.sd_model.depth_model(midas_in),
+ size=conditioning_image.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ (depth_min, depth_max) = torch.aminmax(conditioning)
+ conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+ return conditioning
+
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+ self.is_using_inpainting_conditioning = True
+
+ # Handle the different mask inputs
+ if image_mask is not None:
+ if torch.is_tensor(image_mask):
+ conditioning_mask = image_mask
+ else:
+ conditioning_mask = np.array(image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+ conditioning_image = torch.lerp(
+ source_image,
+ source_image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
+
+ # Encode the new masked image using first stage of network.
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ return image_conditioning
+
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+ # identify itself with a field common to all models. The conditioning_key is also hybrid.
+ if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+ return self.depth2img_image_conditioning(source_image)
+
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+ # Dummy zero conditioning if we're not using inpainting or depth model.
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
+ def close(self):
+ self.sd_model = None
+ self.sampler = None
+
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -121,10 +249,10 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
+ self.comments = comments
self.width = p.width
self.height = p.height
- self.sampler_index = p.sampler_index
- self.sampler = sd_samplers.samplers[p.sampler_index].name
+ self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
@@ -151,17 +279,20 @@ class Processed:
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
- self.all_prompts = all_prompts or [self.prompt]
- self.all_seeds = all_seeds or [self.seed]
- self.all_subseeds = all_subseeds or [self.subseed]
+ self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
+ self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+ self.all_seeds = all_seeds or p.all_seeds or [self.seed]
+ self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self):
obj = {
- "prompt": self.prompt,
+ "prompt": self.all_prompts[0],
"all_prompts": self.all_prompts,
- "negative_prompt": self.negative_prompt,
+ "negative_prompt": self.all_negative_prompts[0],
+ "all_negative_prompts": self.all_negative_prompts,
"seed": self.seed,
"all_seeds": self.all_seeds,
"subseed": self.subseed,
@@ -169,8 +300,7 @@ class Processed:
"subseed_strength": self.subseed_strength,
"width": self.width,
"height": self.height,
- "sampler_index": self.sampler_index,
- "sampler": self.sampler,
+ "sampler_name": self.sampler_name,
"cfg_scale": self.cfg_scale,
"steps": self.steps,
"batch_size": self.batch_size,
@@ -186,11 +316,12 @@ class Processed:
"styles": self.styles,
"job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip,
+ "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
}
return json.dumps(obj)
- def infotext(self, p: StableDiffusionProcessing, index):
+ def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
@@ -210,13 +341,14 @@ def slerp(val, low, high):
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+ eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -256,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
- if opts.eta_noise_seed_delta > 0:
- torch.manual_seed(seed + opts.eta_noise_seed_delta)
+ if eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -297,20 +429,23 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = {
"Steps": p.steps,
- "Sampler": get_correct_sampler(p)[p.sampler_index].name,
+ "Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
+ "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
+ "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
@@ -318,14 +453,38 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
+ negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
+ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+
+ try:
+ for k, v in p.override_settings.items():
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
+ if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
+
+ res = process_images_inner(p)
+
+ finally:
+ # restore opts to original state
+ if p.override_settings_restore_afterwards:
+ for k, v in stored_opts.items():
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+ if k == 'sd_vae': sd_vae.reload_vae_weights()
+
+ return res
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list:
@@ -333,10 +492,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
assert p.prompt is not None
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
devices.torch_gc()
seed = get_fixed_seed(p.seed)
@@ -347,57 +502,71 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
comments = {}
- shared.prompt_styles.apply_styles(p)
-
if type(p.prompt) == list:
- all_prompts = p.prompt
+ p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
else:
- all_prompts = p.batch_size * p.n_iter * [p.prompt]
+ p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
+
+ if type(p.negative_prompt) == list:
+ p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
+ else:
+ p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
if type(seed) == list:
- all_seeds = seed
+ p.all_seeds = seed
else:
- all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
- all_subseeds = subseed
+ p.all_subseeds = subseed
else:
- all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ if p.scripts is not None:
+ p.scripts.process(p)
+
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
- p.init(all_prompts, all_seeds, all_subseeds)
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ p.iteration = n
+
if state.skipped:
state.skipped = False
-
+
if state.interrupted:
break
- prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
+ if p.scripts is not None:
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
with devices.autocast():
- uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@@ -408,10 +577,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- samples_ddim = samples_ddim.to(devices.dtype_vae)
- x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
+ x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+ x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
@@ -421,9 +590,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- if opts.filter_nsfw:
- import modules.safety as safety
- x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
+ if p.scripts is not None:
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -442,22 +610,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- if p.overlay_images is not None and i < len(p.overlay_images):
- overlay = p.overlay_images[i]
-
- if p.paste_to is not None:
- x, y, w, h = p.paste_to
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
-
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
@@ -468,7 +625,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
- del x_samples_ddim
+ del x_samples_ddim
devices.torch_gc()
@@ -490,23 +647,33 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
- self.firstphase_width = firstphase_width
- self.firstphase_height = firstphase_height
- self.truncate_x = 0
- self.truncate_y = 0
+ self.hr_scale = hr_scale
+ self.hr_upscaler = hr_upscaler
+
+ if firstphase_width != 0 or firstphase_height != 0:
+ print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
+ self.hr_scale = self.width / firstphase_width
+ self.width = firstphase_width
+ self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -515,48 +682,50 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- if self.firstphase_width == 0 or self.firstphase_height == 0:
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- firstphase_width_truncated = int(scale * self.width)
- firstphase_height_truncated = int(scale * self.height)
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ if self.hr_upscaler is not None:
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
- else:
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- width_ratio = self.width / self.firstphase_width
- height_ratio = self.height / self.firstphase_height
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and latent_scale_mode is None:
+ assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
- if width_ratio > height_ratio:
- firstphase_width_truncated = self.firstphase_width
- firstphase_height_truncated = self.firstphase_width * self.height / self.width
- else:
- firstphase_width_truncated = self.firstphase_height * self.width / self.height
- firstphase_height_truncated = self.firstphase_height
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
- self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
- self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ if not self.enable_hr:
+ return samples
+ target_width = int(self.width * self.hr_scale)
+ target_height = int(self.height * self.hr_scale)
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+ def save_intermediate(image, index):
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
- if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- return samples
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ return
+
+ if not isinstance(image, Image.Image):
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ if latent_scale_mode is not None:
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
- if opts.use_scale_latent_for_hires_fix:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+ # Avoid making the inpainting conditioning unless necessary as
+ # this does need some extra compute to decode / encode the image again.
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -566,7 +735,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
+
+ save_intermediate(image, i)
+
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
@@ -577,17 +749,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
@@ -595,7 +769,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -603,7 +777,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
- #self.image_unblurred_mask = None
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
@@ -611,65 +784,68 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert
+ self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
self.mask = None
self.nmask = None
+ self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None
- if self.image_mask is not None:
- self.image_mask = self.image_mask.convert('L')
+ image_mask = self.image_mask
- if self.inpainting_mask_invert:
- self.image_mask = ImageOps.invert(self.image_mask)
+ if image_mask is not None:
+ image_mask = image_mask.convert('L')
- #self.image_unblurred_mask = self.image_mask
+ if self.inpainting_mask_invert:
+ image_mask = ImageOps.invert(image_mask)
if self.mask_blur > 0:
- self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+ image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
if self.inpaint_full_res:
- self.mask_for_overlay = self.image_mask
- mask = self.image_mask.convert('L')
+ self.mask_for_overlay = image_mask
+ mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
- self.image_mask = images.resize_image(2, mask, self.width, self.height)
+ image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
- self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
- np_mask = np.array(self.image_mask)
+ image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+ np_mask = np.array(image_mask)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
self.mask_for_overlay = Image.fromarray(np_mask)
self.overlay_images = []
- latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
+ latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
if add_color_corrections:
self.color_corrections = []
imgs = []
for img in self.init_images:
- image = img.convert("RGB")
+ image = images.flatten(img, opts.img2img_background_color)
- if crop_region is None:
+ if crop_region is None and self.resize_mode != 3:
image = images.resize_image(self.resize_mode, image, self.width, self.height)
- if self.image_mask is not None:
+ if image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA'))
+ # crop_region is not None if we are doing inpaint full res
if crop_region is not None:
image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height)
- if self.image_mask is not None:
+ if image_mask is not None:
if self.inpainting_fill != 1:
image = masking.fill(image, latent_mask)
@@ -685,6 +861,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
@@ -697,7 +877,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
- if self.image_mask is not None:
+ if self.resize_mode == 3:
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+ if image_mask is not None:
init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
@@ -714,10 +897,16 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+ if self.initial_noise_multiplier != 1.0:
+ self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
+ x *= self.initial_noise_multiplier
+
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
@@ -725,4 +914,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples \ No newline at end of file
+ return samples
diff --git a/modules/safe.py b/modules/safe.py
index 399165a1..82d44be3 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -23,23 +23,30 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
- if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
- if module == 'numpy.core.multiarray' and name == 'scalar':
- return numpy.core.multiarray.scalar
- if module == 'numpy' and name == 'dtype':
- return numpy.dtype
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+ return getattr(numpy.core.multiarray, name)
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
+ return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
@@ -52,32 +59,37 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
- raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+ raise Exception(f"global '{module}/{name}' is forbidden")
-allowed_zip_names = ["archive/data.pkl", "archive/version"]
-allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
- if name in allowed_zip_names:
- continue
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
-def check_pt(filename):
+def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
- with z.open('archive/data.pkl') as file:
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}")
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}")
+ with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
@@ -85,33 +97,96 @@ def check_pt(filename):
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+ """
+ this function is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename)
+ check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
+class Extra:
+ """
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+ (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+ if module == 'torch' and name in ['float64', 'float16']:
+ return getattr(torch, name)
+
+ return None
+
+with safe.Extra(handler):
+ x = torch.load('model.pt')
+```
+ """
+
+ def __init__(self, handler):
+ self.handler = handler
+
+ def __enter__(self):
+ global global_extra_handler
+
+ assert global_extra_handler is None, 'already inside an Extra() block'
+ global_extra_handler = self.handler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ global global_extra_handler
+
+ global_extra_handler = None
+
+
unsafe_torch_load = torch.load
torch.load = load
+global_extra_handler = None
+
diff --git a/modules/safety.py b/modules/safety.py
deleted file mode 100644
index cff4b278..00000000
--- a/modules/safety.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
- """
- Convert a numpy image or a batch of images to a PIL image.
- """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
- global safety_feature_extractor, safety_checker
-
- if safety_feature_extractor is None:
- safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
- return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
- x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
- x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- return x
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
new file mode 100644
index 00000000..de69fd9f
--- /dev/null
+++ b/modules/script_callbacks.py
@@ -0,0 +1,281 @@
+import sys
+import traceback
+from collections import namedtuple
+import inspect
+from typing import Optional
+
+from fastapi import FastAPI
+from gradio import Blocks
+
+
+def report_exception(c, job):
+ print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+class ImageSaveParams:
+ def __init__(self, image, p, filename, pnginfo):
+ self.image = image
+ """the PIL image itself"""
+
+ self.p = p
+ """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
+
+ self.filename = filename
+ """name of file that the image would be saved to"""
+
+ self.pnginfo = pnginfo
+ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+
+
+class CFGDenoiserParams:
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.image_cond = image_cond
+ """Conditioning image"""
+
+ self.sigma = sigma
+ """Current sigma noise step value"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+
+class UiTrainTabParams:
+ def __init__(self, txt2img_preview_params):
+ self.txt2img_preview_params = txt2img_preview_params
+
+
+class ImageGridLoopParams:
+ def __init__(self, imgs, cols, rows):
+ self.imgs = imgs
+ self.cols = cols
+ self.rows = rows
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callback_map = dict(
+ callbacks_app_started=[],
+ callbacks_model_loaded=[],
+ callbacks_ui_tabs=[],
+ callbacks_ui_train_tabs=[],
+ callbacks_ui_settings=[],
+ callbacks_before_image_saved=[],
+ callbacks_image_saved=[],
+ callbacks_cfg_denoiser=[],
+ callbacks_before_component=[],
+ callbacks_after_component=[],
+ callbacks_image_grid=[],
+)
+
+
+def clear_callbacks():
+ for callback_list in callback_map.values():
+ callback_list.clear()
+
+
+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
+ for c in callback_map['callbacks_app_started']:
+ try:
+ c.callback(demo, app)
+ except Exception:
+ report_exception(c, 'app_started_callback')
+
+
+def model_loaded_callback(sd_model):
+ for c in callback_map['callbacks_model_loaded']:
+ try:
+ c.callback(sd_model)
+ except Exception:
+ report_exception(c, 'model_loaded_callback')
+
+
+def ui_tabs_callback():
+ res = []
+
+ for c in callback_map['callbacks_ui_tabs']:
+ try:
+ res += c.callback() or []
+ except Exception:
+ report_exception(c, 'ui_tabs_callback')
+
+ return res
+
+
+def ui_train_tabs_callback(params: UiTrainTabParams):
+ for c in callback_map['callbacks_ui_train_tabs']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'callbacks_ui_train_tabs')
+
+
+def ui_settings_callback():
+ for c in callback_map['callbacks_ui_settings']:
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'ui_settings_callback')
+
+
+def before_image_saved_callback(params: ImageSaveParams):
+ for c in callback_map['callbacks_before_image_saved']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'before_image_saved_callback')
+
+
+def image_saved_callback(params: ImageSaveParams):
+ for c in callback_map['callbacks_image_saved']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'image_saved_callback')
+
+
+def cfg_denoiser_callback(params: CFGDenoiserParams):
+ for c in callback_map['callbacks_cfg_denoiser']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_denoiser_callback')
+
+
+def before_component_callback(component, **kwargs):
+ for c in callback_map['callbacks_before_component']:
+ try:
+ c.callback(component, **kwargs)
+ except Exception:
+ report_exception(c, 'before_component_callback')
+
+
+def after_component_callback(component, **kwargs):
+ for c in callback_map['callbacks_after_component']:
+ try:
+ c.callback(component, **kwargs)
+ except Exception:
+ report_exception(c, 'after_component_callback')
+
+
+def image_grid_callback(params: ImageGridLoopParams):
+ for c in callback_map['callbacks_image_grid']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'image_grid')
+
+
+def add_callback(callbacks, fun):
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+
+ callbacks.append(ScriptCallback(filename, fun))
+
+
+def remove_current_script_callbacks():
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ if filename == 'unknown file':
+ return
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
+ callback_list.remove(callback_to_remove)
+
+
+def remove_callbacks_for_function(callback_func):
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
+ callback_list.remove(callback_to_remove)
+
+
+def on_app_started(callback):
+ """register a function to be called when the webui started, the gradio `Block` component and
+ fastapi `FastAPI` object are passed as the arguments"""
+ add_callback(callback_map['callbacks_app_started'], callback)
+
+
+def on_model_loaded(callback):
+ """register a function to be called when the stable diffusion model is created; the model is
+ passed as an argument"""
+ add_callback(callback_map['callbacks_model_loaded'], callback)
+
+
+def on_ui_tabs(callback):
+ """register a function to be called when the UI is creating new tabs.
+ The function must either return a None, which means no new tabs to be added, or a list, where
+ each element is a tuple:
+ (gradio_component, title, elem_id)
+
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+ title is tab text displayed to user in the UI
+ elem_id is HTML id for the tab
+ """
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
+
+
+def on_ui_train_tabs(callback):
+ """register a function to be called when the UI is creating new tabs for the train tab.
+ Create your new tabs with gr.Tab.
+ """
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
+def on_ui_settings(callback):
+ """register a function to be called before UI settings are populated; add your settings
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
+ add_callback(callback_map['callbacks_ui_settings'], callback)
+
+
+def on_before_image_saved(callback):
+ """register a function to be called before an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
+ """
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
+
+
+def on_image_saved(callback):
+ """register a function to be called after an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
+ """
+ add_callback(callback_map['callbacks_image_saved'], callback)
+
+
+def on_cfg_denoiser(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+
+
+def on_before_component(callback):
+ """register a function to be called before a component is created.
+ The callback is called with arguments:
+ - component - gradio component that is about to be created.
+ - **kwargs - args to gradio.components.IOComponent.__init__ function
+
+ Use elem_id/label fields of kwargs to figure out which component it is.
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
+ """
+ add_callback(callback_map['callbacks_before_component'], callback)
+
+
+def on_after_component(callback):
+ """register a function to be called after a component is created. See on_before_component for more."""
+ add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid(callback):
+ """register a function to be called before making an image grid.
+ The callback is called with one argument:
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+ """
+ add_callback(callback_map['callbacks_image_grid'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
new file mode 100644
index 00000000..f93f0951
--- /dev/null
+++ b/modules/script_loading.py
@@ -0,0 +1,34 @@
+import os
+import sys
+import traceback
+from types import ModuleType
+
+
+def load_module(path):
+ with open(path, "r", encoding="utf8") as file:
+ text = file.read()
+
+ compiled = compile(text, path, 'exec')
+ module = ModuleType(os.path.basename(path))
+ exec(compiled, module.__dict__)
+
+ return module
+
+
+def preload_extensions(extensions_dir, parser):
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname in sorted(os.listdir(extensions_dir)):
+ preload_script = os.path.join(extensions_dir, dirname, "preload.py")
+ if not os.path.isfile(preload_script):
+ continue
+
+ try:
+ module = load_module(preload_script)
+ if hasattr(module, 'preload'):
+ module.preload(parser)
+
+ except Exception:
+ print(f"Error running preload() for {preload_script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/scripts.py b/modules/scripts.py
index 1039fa9c..722f8685 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,86 +1,211 @@
import os
import sys
import traceback
+from collections import namedtuple
-import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared
+from modules import shared, paths, script_callbacks, extensions, script_loading
+
+AlwaysVisible = object()
+
class Script:
filename = None
args_from = None
args_to = None
+ alwayson = False
+
+ is_txt2img = False
+ is_img2img = False
+
+ """A gr.Group component that has all script's UI inside it"""
+ group = None
+
+ infotext_fields = None
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
+ """
- # The title of the script. This is what will be displayed in the dropdown menu.
def title(self):
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+
raise NotImplementedError()
- # How the script is displayed in the UI. See https://gradio.app/docs/#components
- # for the different UI components you can use and how to create them.
- # Most UI components can return a value, such as a boolean for a checkbox.
- # The returned values are passed to the run method as parameters.
def ui(self, is_img2img):
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be an array of all components that are used in processing.
+ Values of those returned components will be passed to run() and process() functions.
+ """
+
pass
- # Determines when the script should be shown in the dropdown menu via the
- # returned value. As an example:
- # is_img2img is True if the current tab is img2img, and False if it is txt2img.
- # Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img):
+ """
+ is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
+
+ This function should return:
+ - False if the script should not be shown in UI at all
+ - True if the script should be shown in UI if it's selected in the scripts dropdown
+ - script.AlwaysVisible if the script should be shown in UI at all times
+ """
+
return True
- # This is where the additional processing is implemented. The parameters include
- # self, the model object "p" (a StableDiffusionProcessing class, see
- # processing.py), and the parameters returned by the ui method.
- # Custom functions can be defined here, and additional libraries can be imported
- # to be used in processing. The return value should be a Processed object, which is
- # what is returned by the process_images method.
- def run(self, *args):
+ def run(self, p, *args):
+ """
+ This function is called if the script has been selected in the script dropdown.
+ It must do all processing and return the Processed object with results, same as
+ one returned by processing.process_images.
+
+ Usually the processing is done by calling the processing.process_images function.
+
+ args contains all values returned by components from ui()
+ """
+
raise NotImplementedError()
- # The description method is currently unused.
- # To add a description that appears when hovering over the title, amend the "titles"
- # dict in script.js to include the script title (returned by title) as a key, and
- # your description as the value.
+ def process(self, p, *args):
+ """
+ This function is called before processing begins for AlwaysVisible scripts.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def process_batch(self, p, *args, **kwargs):
+ """
+ Same as process(), but called for every batch.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+
+ pass
+
+ def postprocess_batch(self, p, *args, **kwargs):
+ """
+ Same as process_batch(), but called for every batch after it has been generated.
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
+ """
+
+ pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def before_component(self, component, **kwargs):
+ """
+ Called before a component is created.
+ Use elem_id/label fields of kwargs to figure out which component it is.
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
+ You can return created components in the ui() function to add them to the list of arguments for your processing functions
+ """
+
+ pass
+
+ def after_component(self, component, **kwargs):
+ """
+ Called after a component is created. Same as above.
+ """
+
+ pass
+
def describe(self):
+ """unused"""
return ""
+current_basedir = paths.script_path
+
+
+def basedir():
+ """returns the base directory for the current script. For scripts in the main scripts directory,
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
+ """
+ return current_basedir
+
+
scripts_data = []
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
-def load_scripts(basedir):
- if not os.path.exists(basedir):
- return
+def list_scripts(scriptdirname, extension):
+ scripts_list = []
- for filename in sorted(os.listdir(basedir)):
- path = os.path.join(basedir, filename)
+ basedir = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(basedir):
+ for filename in sorted(os.listdir(basedir)):
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- if os.path.splitext(path)[1].lower() != '.py':
- continue
+ for ext in extensions.active():
+ scripts_list += ext.list_files(scriptdirname, extension)
+
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return scripts_list
+
+
+def list_files_with_name(filename):
+ res = []
- if not os.path.isfile(path):
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
+
+ for dirpath in dirs:
+ if not os.path.isdir(dirpath):
continue
+ path = os.path.join(dirpath, filename)
+ if os.path.isfile(path):
+ res.append(path)
+
+ return res
+
+
+def load_scripts():
+ global current_basedir
+ scripts_data.clear()
+ script_callbacks.clear_callbacks()
+
+ scripts_list = list_scripts("scripts", ".py")
+
+ syspath = sys.path
+
+ for scriptfile in sorted(scripts_list):
try:
- with open(path, "r", encoding="utf8") as file:
- text = file.read()
+ if scriptfile.basedir != paths.script_path:
+ sys.path = [scriptfile.basedir] + sys.path
+ current_basedir = scriptfile.basedir
- from types import ModuleType
- compiled = compile(text, path, 'exec')
- module = ModuleType(filename)
- exec(compiled, module.__dict__)
+ module = script_loading.load_module(scriptfile.path)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append((script_class, path))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception:
- print(f"Error loading script: {filename}", file=sys.stderr)
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ sys.path = syspath
+ current_basedir = paths.script_path
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -96,64 +221,94 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
+ self.selectable_scripts = []
+ self.alwayson_scripts = []
self.titles = []
+ self.infotext_fields = []
- def setup_ui(self, is_img2img):
- for script_class, path in scripts_data:
+ def initialize_scripts(self, is_img2img):
+ self.scripts.clear()
+ self.alwayson_scripts.clear()
+ self.selectable_scripts.clear()
+
+ for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
+ script.is_txt2img = not is_img2img
+ script.is_img2img = is_img2img
- if not script.show(is_img2img):
- continue
+ visibility = script.show(script.is_img2img)
- self.scripts.append(script)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
- inputs = [dropdown]
+ def setup_ui(self):
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
- for script in self.scripts:
+ inputs = [None]
+ inputs_alwayson = [True]
+
+ def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
- controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None:
- continue
+ return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- control.visible = False
+
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
inputs += controls
+ inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
+ for script in self.alwayson_scripts:
+ with gr.Group() as group:
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ script.group = group
+
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
+ dropdown.save_to_config = True
+ inputs[0] = dropdown
+
+ for script in self.selectable_scripts:
+ with gr.Group(visible=False) as group:
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ script.group = group
+
def select_script(script_index):
- if 0 < script_index <= len(self.scripts):
- script = self.scripts[script_index-1]
- args_from = script.args_from
- args_to = script.args_to
- else:
- args_from = 0
- args_to = 0
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title):
+ """called when an initial value is set from ui-config.json to show script's UI components"""
+
if title == 'None':
return
+
script_index = self.titles.index(title)
- script = self.scripts[script_index]
- for i in range(script.args_from, script.args_to):
- inputs[i].visible = True
+ self.selectable_scripts[script_index].group.visible = True
dropdown.init_field = init_field
+
dropdown.change(
fn=select_script,
inputs=[dropdown],
- outputs=inputs
+ outputs=[script.group for script in self.selectable_scripts]
)
return inputs
@@ -164,7 +319,7 @@ class ScriptRunner:
if script_index == 0:
return None
- script = self.scripts[script_index-1]
+ script = self.selectable_scripts[script_index-1]
if script is None:
return None
@@ -176,40 +331,112 @@ class ScriptRunner:
return processed
- def reload_sources(self):
+ def process(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process(p, *script_args)
+ except Exception:
+ print(f"Error running process: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def process_batch(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process_batch(p, *script_args, **kwargs)
+ except Exception:
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess(self, p, processed):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess(p, processed, *script_args)
+ except Exception:
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess_batch(self, p, images, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def before_component(self, component, **kwargs):
+ for script in self.scripts:
+ try:
+ script.before_component(component, **kwargs)
+ except Exception:
+ print(f"Error running before_component: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def after_component(self, component, **kwargs):
+ for script in self.scripts:
+ try:
+ script.after_component(component, **kwargs)
+ except Exception:
+ print(f"Error running after_component: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
- with open(script.filename, "r", encoding="utf8") as file:
- args_from = script.args_from
- args_to = script.args_to
- filename = script.filename
- text = file.read()
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
- from types import ModuleType
+ module = cache.get(filename, None)
+ if module is None:
+ module = script_loading.load_module(script.filename)
+ cache[filename] = module
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
+ for key, script_class in module.__dict__.items():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- self.scripts[si] = script_class()
- self.scripts[si].filename = filename
- self.scripts[si].args_from = args_from
- self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+scripts_current: ScriptRunner = None
+
def reload_script_body_only():
- scripts_txt2img.reload_sources()
- scripts_img2img.reload_sources()
+ cache = {}
+ scripts_txt2img.reload_sources(cache)
+ scripts_img2img.reload_sources(cache)
-def reload_scripts(basedir):
+def reload_scripts():
global scripts_txt2img, scripts_img2img
- scripts_data.clear()
- load_scripts(basedir)
+ load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
+
+def IOComponent_init(self, *args, **kwargs):
+ if scripts_current is not None:
+ scripts_current.before_component(self, **kwargs)
+
+ script_callbacks.before_component_callback(self, **kwargs)
+
+ res = original_IOComponent_init(self, *args, **kwargs)
+
+ script_callbacks.after_component_callback(self, **kwargs)
+
+ if scripts_current is not None:
+ scripts_current.after_component(self, **kwargs)
+
+ return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+gr.components.IOComponent.__init__ = IOComponent_init
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
deleted file mode 100644
index 36a996bf..00000000
--- a/modules/scunet_model.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from modules.scunet_model_arch import SCUNet as net
-
-
-class UpscalerScuNET(modules.upscaler.Upscaler):
- def __init__(self, dirname):
- self.name = "ScuNET"
- self.model_name = "ScuNET GAN"
- self.model_name2 = "ScuNET PSNR"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
- self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pth"])
- scalers = []
- add_model2 = True
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
- if name == self.model_name2 or file == self.model_url2:
- add_model2 = False
- try:
- scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
- scalers.append(scaler_data)
- except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- if add_model2:
- scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
- scalers.append(scaler_data2)
- self.scalers = scalers
-
- def do_upscale(self, img: PIL.Image, selected_file):
- torch.cuda.empty_cache()
-
- model = self.load_model(selected_file)
- if model is None:
- return img
-
- device = devices.device_scunet
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
-
- img = img.to(device)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
-
- def load_model(self, path: str):
- device = devices.device_scunet
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
- else:
- filename = path
- if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
- print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
- return None
-
- model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for k, v in model.named_parameters():
- v.requires_grad = False
- model = model.to(device)
-
- return model
-
diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py
deleted file mode 100644
index 43ca8d36..00000000
--- a/modules/scunet_model_arch.py
+++ /dev/null
@@ -1,265 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from timm.models.layers import trunc_normal_, DropPath
-
-
-class WMSA(nn.Module):
- """ Self-attention module in Swin Transformer
- """
-
- def __init__(self, input_dim, output_dim, head_dim, window_size, type):
- super(WMSA, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.head_dim = head_dim
- self.scale = self.head_dim ** -0.5
- self.n_heads = input_dim // head_dim
- self.window_size = window_size
- self.type = type
- self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
-
- self.relative_position_params = nn.Parameter(
- torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
-
- self.linear = nn.Linear(self.input_dim, self.output_dim)
-
- trunc_normal_(self.relative_position_params, std=.02)
- self.relative_position_params = torch.nn.Parameter(
- self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
- 2).transpose(
- 0, 1))
-
- def generate_mask(self, h, w, p, shift):
- """ generating the mask of SW-MSA
- Args:
- shift: shift parameters in CyclicShift.
- Returns:
- attn_mask: should be (1 1 w p p),
- """
- # supporting square.
- attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
- if self.type == 'W':
- return attn_mask
-
- s = p - shift
- attn_mask[-1, :, :s, :, s:, :] = True
- attn_mask[-1, :, s:, :, :s, :] = True
- attn_mask[:, -1, :, :s, :, s:] = True
- attn_mask[:, -1, :, s:, :, :s] = True
- attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
- return attn_mask
-
- def forward(self, x):
- """ Forward pass of Window Multi-head Self-attention module.
- Args:
- x: input tensor with shape of [b h w c];
- attn_mask: attention mask, fill -inf where the value is True;
- Returns:
- output: tensor shape [b h w c]
- """
- if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
- x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
- h_windows = x.size(1)
- w_windows = x.size(2)
- # square validation
- # assert h_windows == w_windows
-
- x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
- qkv = self.embedding_layer(x)
- q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
- sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
- # Adding learnable relative embedding
- sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
- # Using Attn Mask to distinguish different subwindows.
- if self.type != 'W':
- attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
- sim = sim.masked_fill_(attn_mask, float("-inf"))
-
- probs = nn.functional.softmax(sim, dim=-1)
- output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
- output = rearrange(output, 'h b w p c -> b w p (h c)')
- output = self.linear(output)
- output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
-
- if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
- dims=(1, 2))
- return output
-
- def relative_embedding(self):
- cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
- relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
- # negative is allowed
- return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
-
-
-class Block(nn.Module):
- def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer Block
- """
- super(Block, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- assert type in ['W', 'SW']
- self.type = type
- if input_resolution <= window_size:
- self.type = 'W'
-
- self.ln1 = nn.LayerNorm(input_dim)
- self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.ln2 = nn.LayerNorm(input_dim)
- self.mlp = nn.Sequential(
- nn.Linear(input_dim, 4 * input_dim),
- nn.GELU(),
- nn.Linear(4 * input_dim, output_dim),
- )
-
- def forward(self, x):
- x = x + self.drop_path(self.msa(self.ln1(x)))
- x = x + self.drop_path(self.mlp(self.ln2(x)))
- return x
-
-
-class ConvTransBlock(nn.Module):
- def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer and Conv Block
- """
- super(ConvTransBlock, self).__init__()
- self.conv_dim = conv_dim
- self.trans_dim = trans_dim
- self.head_dim = head_dim
- self.window_size = window_size
- self.drop_path = drop_path
- self.type = type
- self.input_resolution = input_resolution
-
- assert self.type in ['W', 'SW']
- if self.input_resolution <= self.window_size:
- self.type = 'W'
-
- self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
- self.type, self.input_resolution)
- self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
- self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
-
- self.conv_block = nn.Sequential(
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
- nn.ReLU(True),
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
- )
-
- def forward(self, x):
- conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
- conv_x = self.conv_block(conv_x) + conv_x
- trans_x = Rearrange('b c h w -> b h w c')(trans_x)
- trans_x = self.trans_block(trans_x)
- trans_x = Rearrange('b h w c -> b c h w')(trans_x)
- res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
- x = x + res
-
- return x
-
-
-class SCUNet(nn.Module):
- # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
- def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
- super(SCUNet, self).__init__()
- if config is None:
- config = [2, 2, 2, 2, 2, 2, 2]
- self.config = config
- self.dim = dim
- self.head_dim = 32
- self.window_size = 8
-
- # drop path rate for each layer
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
-
- self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
-
- begin = 0
- self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[0])] + \
- [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
-
- begin += config[0]
- self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[1])] + \
- [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
-
- begin += config[1]
- self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[2])] + \
- [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
-
- begin += config[2]
- self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 8)
- for i in range(config[3])]
-
- begin += config[3]
- self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[4])]
-
- begin += config[4]
- self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[5])]
-
- begin += config[5]
- self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[6])]
-
- self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
-
- self.m_head = nn.Sequential(*self.m_head)
- self.m_down1 = nn.Sequential(*self.m_down1)
- self.m_down2 = nn.Sequential(*self.m_down2)
- self.m_down3 = nn.Sequential(*self.m_down3)
- self.m_body = nn.Sequential(*self.m_body)
- self.m_up3 = nn.Sequential(*self.m_up3)
- self.m_up2 = nn.Sequential(*self.m_up2)
- self.m_up1 = nn.Sequential(*self.m_up1)
- self.m_tail = nn.Sequential(*self.m_tail)
- # self.apply(self._init_weights)
-
- def forward(self, x0):
-
- h, w = x0.size()[-2:]
- paddingBottom = int(np.ceil(h / 64) * 64 - h)
- paddingRight = int(np.ceil(w / 64) * 64 - w)
- x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
-
- x1 = self.m_head(x0)
- x2 = self.m_down1(x1)
- x3 = self.m_down2(x2)
- x4 = self.m_down3(x3)
- x = self.m_body(x4)
- x = self.m_up3(x + x4)
- x = self.m_up2(x + x3)
- x = self.m_up1(x + x2)
- x = self.m_tail(x + x1)
-
- x = x[..., :h, :w]
-
- return x
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0) \ No newline at end of file
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 984b35c4..fa2cd4bb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,60 +1,81 @@
-import math
-import os
-import sys
-import traceback
import torch
-import numpy as np
-from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts
+from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules.hypernetworks import hypernetwork
+from modules.shared import cmd_opts
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+import ldm.modules.encoders.modules
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+# new memory efficient cross attention blocks do not support hypernets and we already
+# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
+ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
+ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
+
+# silence new console spam from SD2
+ldm.modules.attention.print = lambda *args: None
+ldm.modules.diffusionmodules.model.print = lambda *args: None
+
+
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
+ ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
+ optimization_method = None
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ optimization_method = 'xformers'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ optimization_method = 'V1'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ optimization_method = 'V1'
else:
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+ optimization_method = 'Doggettx'
+ return optimization_method
-def undo_optimizations():
- from modules.hypernetworks import hypernetwork
+def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-def get_target_prompt_token_count(token_count):
- return math.ceil(max(token_count, 1) / 75) * 75
+def fix_checkpoint():
+ ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack:
@@ -63,18 +84,31 @@ class StableDiffusionModelHijack:
layers = None
circular_enabled = False
clip = None
+ optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
- m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ model_embeddings = m.cond_stage_model.roberta.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+ m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
- self.clip = m.cond_stage_model
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
+ m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
+ m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+ self.optimization_method = apply_optimizations()
- apply_optimizations()
+ self.clip = m.cond_stage_model
+
+ fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -86,12 +120,23 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
- if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+
+ elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
+ m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
- model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ self.apply_circular(False)
+ self.layers = None
+ self.clip = None
def apply_circular(self, enable):
if self.circular_enabled == enable:
@@ -107,263 +152,8 @@ class StableDiffusionModelHijack:
def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
- return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
-
-
-class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
- def __init__(self, wrapped, hijack):
- super().__init__()
- self.wrapped = wrapped
- self.hijack: StableDiffusionModelHijack = hijack
- self.tokenizer = wrapped.tokenizer
- self.token_mults = {}
-
- self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
-
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
- for text, ident in tokens_with_parens:
- mult = 1.0
- for c in text:
- if c == '[':
- mult /= 1.1
- if c == ']':
- mult *= 1.1
- if c == '(':
- mult *= 1.1
- if c == ')':
- mult /= 1.1
-
- if mult != 1.0:
- self.token_mults[ident] = mult
-
- def tokenize_line(self, line, used_custom_terms, hijack_comments):
- id_end = self.wrapped.tokenizer.eos_token_id
-
- if opts.enable_emphasis:
- parsed = prompt_parser.parse_prompt_attention(line)
- else:
- parsed = [[line, 1.0]]
-
- tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
-
- fixes = []
- remade_tokens = []
- multipliers = []
- last_comma = -1
-
- for tokens, (text, weight) in zip(tokenized, parsed):
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- if token == self.comma_token:
- last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
- last_comma += 1
- reloc_tokens = remade_tokens[last_comma:]
- reloc_mults = multipliers[last_comma:]
-
- remade_tokens = remade_tokens[:last_comma]
- length = len(remade_tokens)
-
- rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [id_end] * rem + reloc_tokens
- multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
-
- if embedding is None:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- iteration = len(remade_tokens) // 75
- if (len(remade_tokens) + emb_len) // 75 != iteration:
- rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [id_end] * rem
- multipliers += [1.0] * rem
- iteration += 1
- fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- token_count = len(remade_tokens)
- prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens)
-
- remade_tokens = remade_tokens + [id_end] * tokens_to_add
- multipliers = multipliers + [1.0] * tokens_to_add
-
- return remade_tokens, fixes, multipliers, token_count
-
- def process_text(self, texts):
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
-
- cache = {}
- batch_multipliers = []
- for line in texts:
- if line in cache:
- remade_tokens, fixes, multipliers = cache[line]
- else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
- token_count = max(current_token_count, token_count)
-
- cache[line] = (remade_tokens, fixes, multipliers)
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
-
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-
- def process_text_old(self, text):
- id_start = self.wrapped.tokenizer.bos_token_id
- id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- overflowing_words = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
-
- cache = {}
- batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
-
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
- def forward(self, text):
- use_old = opts.use_old_emphasis_implementation
- if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
- else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- if use_old:
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
- z = None
- i = 0
- while max(map(len, remade_batch_tokens)) != 0:
- rem_tokens = [x[75:] for x in remade_batch_tokens]
- rem_multipliers = [x[75:] for x in batch_multipliers]
-
- self.hijack.fixes = []
- for unfiltered in hijack_fixes:
- fixes = []
- for fix in unfiltered:
- if fix[0] == i:
- fixes.append(fix[1])
- self.hijack.fixes.append(fixes)
-
- tokens = []
- multipliers = []
- for j in range(len(remade_batch_tokens)):
- if len(remade_batch_tokens[j]) > 0:
- tokens.append(remade_batch_tokens[j][:75])
- multipliers.append(batch_multipliers[j][:75])
- else:
- tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
- multipliers.append([1.0] * 75)
-
- z1 = self.process_tokens(tokens, multipliers)
- z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- remade_batch_tokens = rem_tokens
- batch_multipliers = rem_multipliers
- i += 1
-
- return z
-
-
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
- batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-
- tokens = torch.asarray(remade_batch_tokens).to(device)
- outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
-
- if opts.CLIP_stop_at_last_layers > 1:
- z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
- z = self.wrapped.transformer.text_model.final_layer_norm(z)
- else:
- z = outputs.last_hidden_state
- # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
- batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
- original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
- new_mean = z.mean()
- z *= original_mean / new_mean
-
- return z
+ return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
@@ -385,8 +175,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)
@@ -403,3 +193,19 @@ def add_circular_option_to_conv_2d():
model_hijack = StableDiffusionModelHijack()
+
+
+def register_buffer(self, name, attr):
+ """
+ Fix register buffer bug for Mac OS.
+ """
+
+ if type(attr) == torch.Tensor:
+ if attr.device != devices.device:
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
+
+ setattr(self, name, attr)
+
+
+ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
+ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
new file mode 100644
index 00000000..5712972f
--- /dev/null
+++ b/modules/sd_hijack_checkpoint.py
@@ -0,0 +1,10 @@
+from torch.utils.checkpoint import checkpoint
+
+def BasicTransformerBlock_forward(self, x, context=None):
+ return checkpoint(self._forward, x, context)
+
+def AttentionBlock_forward(self, x):
+ return checkpoint(self._forward, x)
+
+def ResBlock_forward(self, x, emb):
+ return checkpoint(self._forward, x, emb) \ No newline at end of file
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
new file mode 100644
index 00000000..ca92b142
--- /dev/null
+++ b/modules/sd_hijack_clip.py
@@ -0,0 +1,303 @@
+import math
+
+import torch
+
+from modules import prompt_parser, devices
+from modules.shared import opts
+
+def get_target_prompt_token_count(token_count):
+ return math.ceil(max(token_count, 1) / 75) * 75
+
+
+class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+ def __init__(self, wrapped, hijack):
+ super().__init__()
+ self.wrapped = wrapped
+ self.hijack = hijack
+
+ def tokenize(self, texts):
+ raise NotImplementedError
+
+ def encode_with_transformers(self, tokens):
+ raise NotImplementedError
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ raise NotImplementedError
+
+ def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ if opts.enable_emphasis:
+ parsed = prompt_parser.parse_prompt_attention(line)
+ else:
+ parsed = [[line, 1.0]]
+
+ tokenized = self.tokenize([text for text, _ in parsed])
+
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ last_comma = -1
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ if token == self.comma_token:
+ last_comma = len(remade_tokens)
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ last_comma += 1
+ reloc_tokens = remade_tokens[last_comma:]
+ reloc_mults = multipliers[last_comma:]
+
+ remade_tokens = remade_tokens[:last_comma]
+ length = len(remade_tokens)
+
+ rem = int(math.ceil(length / 75)) * 75 - length
+ remade_tokens += [self.id_end] * rem + reloc_tokens
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
+ if embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ iteration = len(remade_tokens) // 75
+ if (len(remade_tokens) + emb_len) // 75 != iteration:
+ rem = (75 * (iteration + 1) - len(remade_tokens))
+ remade_tokens += [self.id_end] * rem
+ multipliers += [1.0] * rem
+ iteration += 1
+ fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ token_count = len(remade_tokens)
+ prompt_target_length = get_target_prompt_token_count(token_count)
+ tokens_to_add = prompt_target_length - len(remade_tokens)
+
+ remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
+ multipliers = multipliers + [1.0] * tokens_to_add
+
+ return remade_tokens, fixes, multipliers, token_count
+
+ def process_text(self, texts):
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_multipliers = []
+ for line in texts:
+ if line in cache:
+ remade_tokens, fixes, multipliers = cache[line]
+ else:
+ remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ token_count = max(current_token_count, token_count)
+
+ cache[line] = (remade_tokens, fixes, multipliers)
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+ def process_text_old(self, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+ def forward(self, text):
+ use_old = opts.use_old_emphasis_implementation
+ if use_old:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+ else:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ if use_old:
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
+
+ z = None
+ i = 0
+ while max(map(len, remade_batch_tokens)) != 0:
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
+ rem_multipliers = [x[75:] for x in batch_multipliers]
+
+ self.hijack.fixes = []
+ for unfiltered in hijack_fixes:
+ fixes = []
+ for fix in unfiltered:
+ if fix[0] == i:
+ fixes.append(fix[1])
+ self.hijack.fixes.append(fixes)
+
+ tokens = []
+ multipliers = []
+ for j in range(len(remade_batch_tokens)):
+ if len(remade_batch_tokens[j]) > 0:
+ tokens.append(remade_batch_tokens[j][:75])
+ multipliers.append(batch_multipliers[j][:75])
+ else:
+ tokens.append([self.id_end] * 75)
+ multipliers.append([1.0] * 75)
+
+ z1 = self.process_tokens(tokens, multipliers)
+ z = z1 if z is None else torch.cat((z, z1), axis=-2)
+
+ remade_batch_tokens = rem_tokens
+ batch_multipliers = rem_multipliers
+ i += 1
+
+ return z
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
+ tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+
+ if self.id_end != self.id_pad:
+ for batch_pos in range(len(remade_batch_tokens)):
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
+ tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
+
+ z = self.encode_with_transformers(tokens)
+
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
+ batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+ original_mean = z.mean()
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ new_mean = z.mean()
+ z *= original_mean / new_mean
+
+ return z
+
+
+class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+ self.tokenizer = wrapped.tokenizer
+
+ vocab = self.tokenizer.get_vocab()
+
+ self.comma_token = vocab.get(',</w>', None)
+
+ self.token_mults = {}
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ for text, ident in tokens_with_parens:
+ mult = 1.0
+ for c in text:
+ if c == '[':
+ mult /= 1.1
+ if c == ']':
+ mult *= 1.1
+ if c == '(':
+ mult *= 1.1
+ if c == ')':
+ mult /= 1.1
+
+ if mult != 1.0:
+ self.token_mults[ident] = mult
+
+ self.id_start = self.wrapped.tokenizer.bos_token_id
+ self.id_end = self.wrapped.tokenizer.eos_token_id
+ self.id_pad = self.id_end
+
+ def tokenize(self, texts):
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+
+ if opts.CLIP_stop_at_last_layers > 1:
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.transformer.text_model.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
+
+ return embedded
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
new file mode 100644
index 00000000..31d2c898
--- /dev/null
+++ b/modules/sd_hijack_inpainting.py
@@ -0,0 +1,111 @@
+import os
+import torch
+
+from einops import repeat
+from omegaconf import ListConfig
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
+
+def should_hijack_inpainting(checkpoint_info):
+ from modules import sd_models
+
+ ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
+ cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+
+ return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
+
+
+def do_inpainting_hijack():
+ # p_sample_plms is needed because PLMS can't work with dicts as conditionings
+
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
new file mode 100644
index 00000000..f733e852
--- /dev/null
+++ b/modules/sd_hijack_open_clip.py
@@ -0,0 +1,37 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+tokenizer = open_clip.tokenizer._tokenizer
+
+
+class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+ self.id_start = tokenizer.encoder["<start_of_text>"]
+ self.id_end = tokenizer.encoder["<end_of_text>"]
+ self.id_pad = 0
+
+ def tokenize(self, texts):
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+ tokenized = [tokenizer.encode(text) for text in texts]
+
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
+ z = self.wrapped.encode_with_transformer(tokens)
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ ids = tokenizer.encode(init_text)
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+
+ return embedded
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..02c87f40 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -127,7 +127,7 @@ def check_for_psutil():
invokeAI_mps_available = check_for_psutil()
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
return r
def einsum_op_mps_v1(q, k, v):
- if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ if slice_size % 4096 == 0:
+ slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[1] <= 4096:
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v):
return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
- if mem_total_gb >= 32:
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
new file mode 100644
index 00000000..18daf8c1
--- /dev/null
+++ b/modules/sd_hijack_unet.py
@@ -0,0 +1,30 @@
+import torch
+
+
+class TorchHijackForUnet:
+ """
+ This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+ this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
+ """
+
+ def __getattr__(self, item):
+ if item == 'cat':
+ return self.cat
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+ def cat(self, tensors, *args, **kwargs):
+ if len(tensors) == 2:
+ a, b = tensors
+ if a.shape[-2:] != b.shape[-2:]:
+ a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+ tensors = (a, b)
+
+ return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.id_start = wrapped.config.bos_token_id
+ self.id_end = wrapped.config.eos_token_id
+ self.id_pad = wrapped.config.pad_token_id
+
+ self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
+
+ def encode_with_transformers(self, tokens):
+ # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+ # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+ # layer to work with - you have to use the last
+
+ attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+ features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+ z = features['projection_state']
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.roberta.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+ return embedded
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eae22e87..76a89e88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,26 +1,33 @@
import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
+import re
+import safetensors.torch
from omegaconf import OmegaConf
+from os import mkdir
+from urllib import request
+import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
- from transformers import logging
+ from transformers import logging, CLIPModel
logging.set_verbosity_error()
except Exception:
@@ -32,15 +39,26 @@ def setup_model():
os.makedirs(model_path)
list_models()
+ enable_midas_autodownload()
-def checkpoint_tiles():
- return sorted([x.title for x in checkpoints_list.values()])
+def checkpoint_tiles():
+ convert = lambda name: int(name) if name.isdigit() else name.lower()
+ alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+ return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+
+
+def find_checkpoint_config(info):
+ config = os.path.splitext(info.filename)[0] + ".yaml"
+ if os.path.exists(config):
+ return config
+
+ return shared.cmd_opts.config
def list_models():
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+ model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
@@ -63,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -71,12 +89,7 @@ def list_models():
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
- basename, _ = os.path.splitext(filename)
- config = basename + ".yaml"
- if not os.path.exists(config):
- config = shared.cmd_opts.config
-
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
def get_closet_checkpoint_match(searchString):
@@ -101,18 +114,19 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
+
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
if len(checkpoints_list) == 0:
- print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+ print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
checkpoint_info = next(iter(checkpoints_list.values()))
@@ -138,8 +152,8 @@ def transform_checkpoint_dict_key(k):
def get_state_dict_from_checkpoint(pl_sd):
- if "state_dict" in pl_sd:
- pl_sd = pl_sd["state_dict"]
+ pl_sd = pl_sd.pop("state_dict", pl_sd)
+ pl_sd.pop("state_dict", None)
sd = {}
for k, v in pl_sd.items():
@@ -154,64 +168,156 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
-def load_model_weights(model, checkpoint_info):
+def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
+ _, extension = os.path.splitext(checkpoint_file)
+ if extension.lower() == ".safetensors":
+ device = map_location or shared.weight_load_location
+ if device is None:
+ device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
+
+ if print_global_state and "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+
+ sd = get_state_dict_from_checkpoint(pl_sd)
+ return sd
+
+
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- if checkpoint_info not in checkpoints_loaded:
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ cache_enabled = shared.opts.sd_checkpoint_cache > 0
- pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
+ if cache_enabled and checkpoint_info in checkpoints_loaded:
+ # use checkpoint cache
+ print(f"Loading weights [{sd_model_hash}] from cache")
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
+ else:
+ # load from file
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
- sd = get_state_dict_from_checkpoint(pl_sd)
- missing, extra = model.load_state_dict(sd, strict=False)
+ sd = read_state_dict(checkpoint_file)
+ model.load_state_dict(sd, strict=False)
+ del sd
+
+ if cache_enabled:
+ # cache newly loaded model
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
+ vae = model.first_stage_model
+
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+ if shared.cmd_opts.no_half_vae:
+ model.first_stage_model = None
+
model.half()
+ model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-
- if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
- vae_file = shared.cmd_opts.vae_path
-
- if os.path.exists(vae_file):
- print(f"Loading VAE weights from: {vae_file}")
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
- model.first_stage_model.load_state_dict(vae_dict)
-
model.first_stage_model.to(devices.dtype_vae)
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ # clean up cache if limit is reached
+ if cache_enabled:
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
- else:
- print(f"Loading weights [{sd_model_hash}] from cache")
- checkpoints_loaded.move_to_end(checkpoint_info)
- model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ model.logvar = model.logvar.to(devices.device) # fix for training
+
+ sd_vae.delete_base_vae()
+ sd_vae.clear_loaded_vae()
+ vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+ sd_vae.load_vae(model, vae_file)
+
-def load_model():
+def enable_midas_autodownload():
+ """
+ Gives the ldm.modules.midas.api.load_model function automatic downloading.
+
+ When the 512-depth-ema model, and other future models like it, is loaded,
+ it calls midas.api.load_model to load the associated midas depth model.
+ This function applies a wrapper to download the model to the correct
+ location automatically.
+ """
+
+ midas_path = os.path.join(models_path, 'midas')
+
+ # stable-diffusion-stability-ai hard-codes the midas model path to
+ # a location that differs from where other scripts using this model look.
+ # HACK: Overriding the path here.
+ for k, v in midas.api.ISL_PATHS.items():
+ file_name = os.path.basename(v)
+ midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
+
+ midas_urls = {
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
+ }
+
+ midas.api.load_model_inner = midas.api.load_model
+
+ def load_model_wrapper(model_type):
+ path = midas.api.ISL_PATHS[model_type]
+ if not os.path.exists(path):
+ if not os.path.exists(midas_path):
+ mkdir(midas_path)
+
+ print(f"Downloading midas model weights for {model_type} to {path}")
+ request.urlretrieve(midas_urls[model_type], path)
+ print(f"{model_type} downloaded")
+
+ return midas.api.load_model_inner(model_type)
+
+ midas.api.load_model = load_model_wrapper
+
+
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
- checkpoint_info = select_checkpoint()
+ checkpoint_info = checkpoint_info or select_checkpoint()
+ checkpoint_config = find_checkpoint_config(checkpoint_info)
+
+ if checkpoint_config != shared.cmd_opts.config:
+ print(f"Loading config from: {checkpoint_config}")
+
+ if shared.sd_model:
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ gc.collect()
+ devices.torch_gc()
- if checkpoint_info.config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_info.config}")
+ sd_config = OmegaConf.load(checkpoint_config)
+
+ if should_hijack_inpainting(checkpoint_info):
+ # Hardcoded config for now...
+ sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+ sd_config.model.params.conditioning_key = "hybrid"
+ sd_config.model.params.unet_config.params.in_channels = 9
+ sd_config.model.params.finetune_keys = None
+
+ if not hasattr(sd_config.model.params, "use_ema"):
+ sd_config.model.params.use_ema = False
+
+ do_inpainting_hijack()
+
+ if shared.cmd_opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
- sd_config = OmegaConf.load(checkpoint_info.config)
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -222,21 +328,34 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
+ shared.sd_model = sd_model
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
+
+ script_callbacks.model_loaded_callback(sd_model)
+
+ print("Model loaded.")
- print(f"Model loaded.")
return sd_model
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ if not sd_model:
+ sd_model = shared.sd_model
+
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_config = find_checkpoint_config(current_checkpoint_info)
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+ if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ del sd_model
checkpoints_loaded.clear()
- shared.sd_model = load_model()
+ load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -246,12 +365,19 @@ def reload_model_weights(sd_model, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
- sd_hijack.model_hijack.hijack(sd_model)
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ print("Weights loaded.")
- print(f"Weights loaded.")
return sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..e904d860 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,32 +1,41 @@
-from collections import namedtuple
+from collections import namedtuple, deque
import numpy as np
+from math import floor
import torch
import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
+import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing
+from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
@@ -40,16 +49,24 @@ all_samplers = [
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
+all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
+samplers_map = {}
-def create_sampler_with_index(list_of_configs, index, model):
- config = list_of_configs[index]
+def create_sampler(name, model):
+ if name is not None:
+ config = all_samplers_map.get(name, None)
+ else:
+ config = all_samplers[0]
+
+ assert config is not None, f'bad sampler name: {name}'
+
sampler = config.constructor(model)
sampler.config = config
-
+
return sampler
@@ -62,6 +79,12 @@ def set_samplers():
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+ samplers_map.clear()
+ for sampler in all_samplers:
+ samplers_map[sampler.name.lower()] = sampler.name
+ for alias in sampler.aliases:
+ samplers_map[alias.lower()] = sampler.name
+
set_samplers()
@@ -71,6 +94,7 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
@@ -82,14 +106,34 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def sample_to_image(samples):
- x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
def store_latent(decoded):
state.current_latent = decoded
@@ -105,7 +149,8 @@ class InterruptedException(BaseException):
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
@@ -117,6 +162,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def number_of_needed_noises(self, p):
return 0
@@ -136,6 +183,12 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -157,6 +210,12 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
@@ -182,39 +241,52 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
- steps, t_enc = setup_img2img_steps(p, steps)
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = setup_img2img_steps(p, steps)
+ steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
- # existing code fails with certain step counts, like 9
- try:
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
+ self.last_latent = x
self.step = 0
- samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
+ self.last_latent = x
self.step = 0
- steps = steps or p.steps
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
- # existing code fails with certain step counts, like 9
- try:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
- except Exception:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
+ if image_conditioning is not None:
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
@@ -228,7 +300,17 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
- def forward(self, x, sigma, uncond, cond, cond_scale):
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -239,35 +321,37 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -278,34 +362,63 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack:
- def __init__(self, kdiff_sampler):
- self.kdiff_sampler = kdiff_sampler
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item):
if item == 'randn_like':
- return self.kdiff_sampler.randn_like
+ return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ if x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+ else:
+ return torch.randn_like(x)
+
+
+# MPS fix for randn in torchsde
+def torchsde_randn(size, dtype, device, seed):
+ if device.type == 'mps':
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ else:
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
+
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
- self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
- self.sampler_noise_index = 0
self.stop_at = None
self.eta = None
self.default_eta = 1.0
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
@@ -330,26 +443,13 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p):
return p.steps
- def randn_like(self, x):
- noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
- if noise is not None and x.shape == noise.shape:
- res = noise
- else:
- res = torch.randn_like(x)
-
- self.sampler_noise_index += 1
- return res
-
def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0
- self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
- if self.sampler_noises is not None:
- k_diffusion.sampling.torch = TorchHijack(self)
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
extra_params_kwargs = {}
for param_name in self.extra_params:
@@ -361,16 +461,26 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
- steps, t_enc = setup_img2img_steps(p, steps)
-
+ def get_sigmas(self, p, steps):
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
+ if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+ return sigmas
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = setup_img2img_steps(p, steps)
+
+ sigmas = self.get_sigmas(p, steps)
+
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
@@ -388,20 +498,21 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
+ sigmas = self.get_sigmas(p, steps)
x = x * sigmas[0]
@@ -414,7 +525,13 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
new file mode 100644
index 00000000..ac71d62d
--- /dev/null
+++ b/modules/sd_vae.py
@@ -0,0 +1,231 @@
+import torch
+import os
+import collections
+from collections import namedtuple
+from modules import shared, devices, script_callbacks
+from modules.paths import models_path
+import glob
+from copy import deepcopy
+
+
+model_dir = "Stable-diffusion"
+model_path = os.path.abspath(os.path.join(models_path, model_dir))
+vae_dir = "VAE"
+vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
+
+
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
+default_vae_dict = {"auto": "auto", "None": None, None: None}
+default_vae_list = ["auto", "None"]
+
+
+default_vae_values = [default_vae_dict[x] for x in default_vae_list]
+vae_dict = dict(default_vae_dict)
+vae_list = list(default_vae_list)
+first_load = True
+
+
+base_vae = None
+loaded_vae_file = None
+checkpoint_info = None
+
+checkpoints_loaded = collections.OrderedDict()
+
+def get_base_vae(model):
+ if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
+ return base_vae
+ return None
+
+
+def store_base_vae(model):
+ global base_vae, checkpoint_info
+ if checkpoint_info != model.sd_checkpoint_info:
+ assert not loaded_vae_file, "Trying to store non-base VAE!"
+ base_vae = deepcopy(model.first_stage_model.state_dict())
+ checkpoint_info = model.sd_checkpoint_info
+
+
+def delete_base_vae():
+ global base_vae, checkpoint_info
+ base_vae = None
+ checkpoint_info = None
+
+
+def restore_base_vae(model):
+ global loaded_vae_file
+ if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
+ print("Restoring base VAE")
+ _load_vae_dict(model, base_vae)
+ loaded_vae_file = None
+ delete_base_vae()
+
+
+def get_filename(filepath):
+ return os.path.splitext(os.path.basename(filepath))[0]
+
+
+def refresh_vae_list(vae_path=vae_path, model_path=model_path):
+ global vae_dict, vae_list
+ res = {}
+ candidates = [
+ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
+ ]
+ if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
+ candidates.append(shared.cmd_opts.vae_path)
+ for filepath in candidates:
+ name = get_filename(filepath)
+ res[name] = filepath
+ vae_list.clear()
+ vae_list.extend(default_vae_list)
+ vae_list.extend(list(res.keys()))
+ vae_dict.clear()
+ vae_dict.update(res)
+ vae_dict.update(default_vae_dict)
+ return vae_list
+
+
+def get_vae_from_settings(vae_file="auto"):
+ # else, we load from settings, if not set to be default
+ if vae_file == "auto" and shared.opts.sd_vae is not None:
+ # if saved VAE settings isn't recognized, fallback to auto
+ vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
+ # if VAE selected but not found, fallback to auto
+ if vae_file not in default_vae_values and not os.path.isfile(vae_file):
+ vae_file = "auto"
+ print(f"Selected VAE doesn't exist: {vae_file}")
+ return vae_file
+
+
+def resolve_vae(checkpoint_file=None, vae_file="auto"):
+ global first_load, vae_dict, vae_list
+
+ # if vae_file argument is provided, it takes priority, but not saved
+ if vae_file and vae_file not in default_vae_list:
+ if not os.path.isfile(vae_file):
+ print(f"VAE provided as function argument doesn't exist: {vae_file}")
+ vae_file = "auto"
+ # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
+ if first_load and shared.cmd_opts.vae_path is not None:
+ if os.path.isfile(shared.cmd_opts.vae_path):
+ vae_file = shared.cmd_opts.vae_path
+ shared.opts.data['sd_vae'] = get_filename(vae_file)
+ else:
+ print(f"VAE provided as command line argument doesn't exist: {vae_file}")
+ # fallback to selector in settings, if vae selector not set to act as default fallback
+ if not shared.opts.sd_vae_as_default:
+ vae_file = get_vae_from_settings(vae_file)
+ # vae-path cmd arg takes priority for auto
+ if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
+ if os.path.isfile(shared.cmd_opts.vae_path):
+ vae_file = shared.cmd_opts.vae_path
+ print(f"Using VAE provided as command line argument: {vae_file}")
+ # if still not found, try look for ".vae.pt" beside model
+ model_path = os.path.splitext(checkpoint_file)[0]
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.pt"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print(f"Using VAE found similar to selected model: {vae_file}")
+ # if still not found, try look for ".vae.ckpt" beside model
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.ckpt"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print(f"Using VAE found similar to selected model: {vae_file}")
+ # No more fallbacks for auto
+ if vae_file == "auto":
+ vae_file = None
+ # Last check, just because
+ if vae_file and not os.path.exists(vae_file):
+ vae_file = None
+
+ return vae_file
+
+
+def load_vae(model, vae_file=None):
+ global first_load, vae_dict, vae_list, loaded_vae_file
+ # save_settings = False
+
+ cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
+
+ if vae_file:
+ if cache_enabled and vae_file in checkpoints_loaded:
+ # use vae checkpoint cache
+ print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+ store_base_vae(model)
+ _load_vae_dict(model, checkpoints_loaded[vae_file])
+ else:
+ assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
+ print(f"Loading VAE weights from: {vae_file}")
+ store_base_vae(model)
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ _load_vae_dict(model, vae_dict_1)
+
+ if cache_enabled:
+ # cache newly loaded vae
+ checkpoints_loaded[vae_file] = vae_dict_1.copy()
+
+ # clean up cache if limit is reached
+ if cache_enabled:
+ while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
+ checkpoints_loaded.popitem(last=False) # LRU
+
+ # If vae used is not in dict, update it
+ # It will be removed on refresh though
+ vae_opt = get_filename(vae_file)
+ if vae_opt not in vae_dict:
+ vae_dict[vae_opt] = vae_file
+ vae_list.append(vae_opt)
+ elif loaded_vae_file:
+ restore_base_vae(model)
+
+ loaded_vae_file = vae_file
+
+ first_load = False
+
+
+# don't call this from outside
+def _load_vae_dict(model, vae_dict_1):
+ model.first_stage_model.load_state_dict(vae_dict_1)
+ model.first_stage_model.to(devices.dtype_vae)
+
+def clear_loaded_vae():
+ global loaded_vae_file
+ loaded_vae_file = None
+
+def reload_vae_weights(sd_model=None, vae_file="auto"):
+ from modules import lowvram, devices, sd_hijack
+
+ if not sd_model:
+ sd_model = shared.sd_model
+
+ checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_file = checkpoint_info.filename
+ vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+ if loaded_vae_file == vae_file:
+ return
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ sd_model.to(devices.cpu)
+
+ sd_hijack.model_hijack.undo_hijack(sd_model)
+
+ load_vae(sd_model, vae_file)
+
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
+
+ print("VAE Weights loaded.")
+ return sd_model
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
new file mode 100644
index 00000000..0a58542d
--- /dev/null
+++ b/modules/sd_vae_approx.py
@@ -0,0 +1,58 @@
+import os
+
+import torch
+from torch import nn
+from modules import devices, paths
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+ def __init__(self):
+ super(VAEApprox, self).__init__()
+ self.conv1 = nn.Conv2d(4, 8, (7, 7))
+ self.conv2 = nn.Conv2d(8, 16, (5, 5))
+ self.conv3 = nn.Conv2d(16, 32, (3, 3))
+ self.conv4 = nn.Conv2d(32, 64, (3, 3))
+ self.conv5 = nn.Conv2d(64, 32, (3, 3))
+ self.conv6 = nn.Conv2d(32, 16, (3, 3))
+ self.conv7 = nn.Conv2d(16, 8, (3, 3))
+ self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+ def forward(self, x):
+ extra = 11
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ x = nn.functional.pad(x, (extra, extra, extra, extra))
+
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+ x = layer(x)
+ x = nn.functional.leaky_relu(x, 0.1)
+
+ return x
+
+
+def model():
+ global sd_vae_approx_model
+
+ if sd_vae_approx_model is None:
+ sd_vae_approx_model = VAEApprox()
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.eval()
+ sd_vae_approx_model.to(devices.device, devices.dtype)
+
+ return sd_vae_approx_model
+
+
+def cheap_approximation(sample):
+ # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
+
+ coefs = torch.tensor([
+ [0.298, 0.207, 0.208],
+ [0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]).to(sample.device)
+
+ x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+
+ return x_sample
diff --git a/modules/shared.py b/modules/shared.py
index faede821..54a6ba23 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,24 +3,27 @@ import datetime
import json
import os
import sys
+import time
+from PIL import Image
import gradio as gr
import tqdm
import modules.artists
import modules.interrogate
import modules.memmon
-import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
-from modules.hypernetworks import hypernetwork
+from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path
+
+demo = None
+
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@@ -39,34 +42,35 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
-parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
-parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
-parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
-parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
+parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
+parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -76,12 +80,27 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
-parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
+parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
+parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
+parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+
+script_loading.preload_extensions(extensions.extensions_dir, parser)
+script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
cmd_opts = parser.parse_args()
-restricted_opts = [
+
+restricted_opts = {
"samples_filename_pattern",
+ "directories_filename_pattern",
"outdir_samples",
"outdir_txt2img_samples",
"outdir_img2img_samples",
@@ -89,10 +108,23 @@ restricted_opts = [
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
+}
+
+ui_reorder_categories = [
+ "sampler",
+ "dimensions",
+ "cfg",
+ "seed",
+ "checkboxes",
+ "hires_fix",
+ "batch",
+ "scripts",
]
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
+
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
+ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -103,10 +135,12 @@ xformers_available = False
config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
-hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+hypernetworks = {}
loaded_hypernetwork = None
+
def reload_hypernetworks():
+ from modules.hypernetworks import hypernetwork
global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
@@ -126,6 +160,8 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
+ time_start = None
+ need_restart = False
def skip(self):
self.skipped = True
@@ -134,12 +170,67 @@ class State:
self.interrupted = True
def nextjob(self):
+ if opts.show_progress_every_n_steps == -1:
+ self.do_set_current_image()
+
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
- def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
+ def dict(self):
+ obj = {
+ "skipped": self.skipped,
+ "interrupted": self.interrupted,
+ "job": self.job,
+ "job_count": self.job_count,
+ "job_timestamp": self.job_timestamp,
+ "job_no": self.job_no,
+ "sampling_step": self.sampling_step,
+ "sampling_steps": self.sampling_steps,
+ }
+
+ return obj
+
+ def begin(self):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+ self.time_start = time.time()
+
+ devices.torch_gc()
+
+ def end(self):
+ self.job = ""
+ self.job_count = 0
+
+ devices.torch_gc()
+
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ self.do_set_current_image()
+
+ def do_set_current_image(self):
+ if self.current_latent is None:
+ return
+
+ import modules.sd_samplers
+ if opts.show_progress_grid:
+ self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+ else:
+ self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+
+ self.current_image_sampling_step = self.sampling_step
state = State()
@@ -153,8 +244,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-localization.list_localizations(cmd_opts.localizations_dir)
-
def realesrgan_models_names():
import modules.realesrgan_model
@@ -162,13 +251,13 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
- self.section = None
+ self.section = section
self.refresh = refresh
@@ -179,6 +268,21 @@ def options_section(section_identifier, options_dict):
return options_dict
+def list_checkpoint_tiles():
+ import modules.sd_models
+ return modules.sd_models.checkpoint_tiles()
+
+
+def refresh_checkpoints():
+ import modules.sd_models
+ return modules.sd_models.list_models()
+
+
+def list_samplers():
+ import modules.sd_samplers
+ return modules.sd_samplers.all_samplers
+
+
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
options_templates = {}
@@ -186,7 +290,8 @@ options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern"),
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
+ "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
@@ -198,12 +303,19 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+ "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
+
+ "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
+ "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
+
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -221,19 +333,15 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
- "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
+ "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
- "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
- "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
- "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -249,31 +357,42 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+ "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+ "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
- "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
- 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
+options_templates.update(options_section(('compatibility', "Compatibility"), {
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
+ "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+}))
+
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
@@ -286,26 +405,34 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
+ "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
+ "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
@@ -315,6 +442,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
+options_templates.update(options_section((None, "Hidden options"), {
+ "disabled_extensions": OptionInfo([], "Disable those extensions"),
+}))
+
+options_templates.update()
+
class Options:
data = None
@@ -326,8 +459,19 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
- if key in self.data:
+ if key in self.data or key in self.data_labels:
+ assert not cmd_opts.freeze_settings, "changing settings is disabled"
+
+ info = opts.data_labels.get(key, None)
+ comp_args = info.component_args if info else None
+ if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
self.data[key] = value
+ return
return super(Options, self).__setattr__(key, value)
@@ -341,9 +485,33 @@ class Options:
return super(Options, self).__getattribute__(item)
+ def set(self, key, value):
+ """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
+
+ oldval = self.data.get(key, None)
+ if oldval == value:
+ return False
+
+ try:
+ setattr(self, key, value)
+ except RuntimeError:
+ return False
+
+ if self.data_labels[key].onchange is not None:
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
+
+ return True
+
def save(self, filename):
+ assert not cmd_opts.freeze_settings, "saving settings is disabled"
+
with open(filename, "w", encoding="utf8") as file:
- json.dump(self.data, file)
+ json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
@@ -368,25 +536,51 @@ class Options:
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
- def onchange(self, key, func):
+ def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
- func()
+ if call:
+ func()
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for k, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+
+ self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+latent_upscale_default_mode = "Latent"
+latent_upscale_modes = {
+ "Latent": {"mode": "bilinear", "antialias": False},
+ "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+ "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+ "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
+ "Latent (nearest)": {"mode": "nearest", "antialias": False},
+}
+
sd_upscalers = []
sd_model = None
+clip_model = None
+
progress_print_out = sys.stdout
@@ -426,3 +620,8 @@ total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
+
+
+def listfiles(dirname):
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
+ return [file for file in filenames if os.path.isfile(file)]
diff --git a/modules/styles.py b/modules/styles.py
index 3bf5c5b6..ce6e71ca 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -65,17 +65,6 @@ class StyleDatabase:
def apply_negative_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
- def apply_styles(self, p: StableDiffusionProcessing) -> None:
- if isinstance(p.prompt, list):
- p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
- else:
- p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
-
- if isinstance(p.negative_prompt, list):
- p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
- else:
- p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
-
def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv")
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
deleted file mode 100644
index baa02e3d..00000000
--- a/modules/swinir_model.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import contextlib
-import os
-
-import numpy as np
-import torch
-from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
-from tqdm import tqdm
-
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
-from modules.swinir_model_arch import SwinIR as net
-from modules.swinir_model_arch_v2 import Swin2SR as net2
-from modules.upscaler import Upscaler, UpscalerData
-
-precision_scope = (
- torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-)
-
-
-class UpscalerSwinIR(Upscaler):
- def __init__(self, dirname):
- self.name = "SwinIR"
- self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
- "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
- "-L_x4_GAN.pth "
- self.model_name = "SwinIR 4x"
- self.user_path = dirname
- super().__init__()
- scalers = []
- model_files = self.find_models(ext_filter=[".pt", ".pth"])
- for model in model_files:
- if "http" in model:
- name = self.model_name
- else:
- name = modelloader.friendly_name(model)
- model_data = UpscalerData(name, model, self)
- scalers.append(model_data)
- self.scalers = scalers
-
- def do_upscale(self, img, model_file):
- model = self.load_model(model_file)
- if model is None:
- return img
- model = model.to(device)
- img = upscale(img, model)
- try:
- torch.cuda.empty_cache()
- except:
- pass
- return img
-
- def load_model(self, path, scale=4):
- if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
- else:
- filename = path
- if filename is None or not os.path.exists(filename):
- return None
- if filename.endswith(".v2.pth"):
- model = net2(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
- )
- params = None
- else:
- model = net(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
- params = "params_ema"
-
- pretrained_model = torch.load(filename)
- if params is not None:
- model.load_state_dict(pretrained_model[params], strict=True)
- else:
- model.load_state_dict(pretrained_model, strict=True)
- if not cmd_opts.no_half:
- model = model.half()
- return model
-
-
-def upscale(
- img,
- model,
- tile=opts.SWIN_tile,
- tile_overlap=opts.SWIN_tile_overlap,
- window_size=8,
- scale=4,
-):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
- with torch.no_grad(), precision_scope("cuda"):
- _, _, h_old, w_old = img.size()
- h_pad = (h_old // window_size + 1) * window_size - h_old
- w_pad = (w_old // window_size + 1) * window_size - w_old
- img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
- img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(img, model, tile, tile_overlap, window_size, scale)
- output = output[..., : h_old * scale, : w_old * scale]
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
- if output.ndim == 3:
- output = np.transpose(
- output[[2, 1, 0], :, :], (1, 2, 0)
- ) # CHW-RGB to HCW-BGR
- output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
- for h_idx in h_idx_list:
- for w_idx in w_idx_list:
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
diff --git a/modules/swinir_model_arch.py b/modules/swinir_model_arch.py
deleted file mode 100644
index 863f42db..00000000
--- a/modules/swinir_model_arch.py
+++ /dev/null
@@ -1,867 +0,0 @@
-# -----------------------------------------------------------------------------------
-# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
-# Originally Written by Ze Liu, Modified by Jingyun Liang.
-# -----------------------------------------------------------------------------------
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- flops = 0
- H, W = self.img_size
- if self.norm is not None:
- flops += H * W * self.embed_dim
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-class SwinIR(nn.Module):
- r""" SwinIR
- A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(SwinIR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- if self.upscale == 4:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- if self.upscale == 4:
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
-
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = SwinIR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py
deleted file mode 100644
index 0e28ae6e..00000000
--- a/modules/swinir_model_arch_v2.py
+++ /dev/null
@@ -1,1017 +0,0 @@
-# -----------------------------------------------------------------------------------
-# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
-
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim))
- self.v_bias = nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, ' \
- f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size))
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- #assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- pretrained_window_size=0):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0)
- nn.init.constant_(blk.norm1.weight, 0)
- nn.init.constant_(blk.norm2.bias, 0)
- nn.init.constant_(blk.norm2.weight, 0)
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-
-class Swin2SR(nn.Module):
- r""" Swin2SR
- A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(Swin2SR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
-
- if self.upsampler == 'pixelshuffle_hf':
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffle_aux':
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffle_hf':
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- assert self.upscale == 4, 'only support x4 now.'
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffle_aux':
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == 'pixelshuffle_hf':
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
- else:
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = Swin2SR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape) \ No newline at end of file
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
new file mode 100644
index 00000000..68e1103c
--- /dev/null
+++ b/modules/textual_inversion/autocrop.py
@@ -0,0 +1,341 @@
+import cv2
+import requests
+import os
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
+
+ focus = focal_point(im_debug, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
+
+ pois = []
+
+ weight_pref_total = 0
+ if len(corner_points) > 0:
+ weight_pref_total += settings.corner_points_weight
+ if len(entropy_points) > 0:
+ weight_pref_total += settings.entropy_points_weight
+ if len(face_points) > 0:
+ weight_pref_total += settings.face_points_weight
+
+ corner_centroid = None
+ if len(corner_points) > 0:
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
+
+ entropy_centroid = None
+ if len(entropy_points) > 0:
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
+
+ face_centroid = None
+ if len(face_points) > 0:
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
+
+ average_point = poi_average(pois, settings)
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ if settings.dnn_model_path is not None:
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.9, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
+ size = w,
+ weight = 1/len(faces[1])
+ )
+ )
+ return results
+ else:
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
+ return []
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+def centroid(pois):
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
+
+def poi_average(pois, settings):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
+ avg_x = round(weight and x / weight)
+ avg_y = round(weight and y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+ return w > h
+
+
+def is_portrait(w, h):
+ return h > w
+
+
+def is_square(w, h):
+ return w == h
+
+
+def download_and_cache_models(dirname):
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
+
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
+
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..88d68c76 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,7 +3,7 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import random
@@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared
import re
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, latent=None, filename_text=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename
- self.latent = latent
self.filename_text = filename_text
- self.cond = None
- self.cond_text = None
+ self.latent_dist = latent_dist
+ self.latent_sample = latent_sample
+ self.cond = cond
+ self.cond_text = cond_text
+ self.pixel_values = pixel_values
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -42,12 +45,19 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
-
- cond_model = shared.sd_model.cond_stage_model
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+
+
+ self.shuffle_tags = shuffle_tags
+ self.tag_drop_out = tag_drop_out
+
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
+ if shared.state.interrupted:
+ raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@@ -69,53 +79,94 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
- torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
- torchdata = torch.moveaxis(torchdata, 2, 0)
-
- init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
- init_latent = init_latent.to(devices.cpu)
-
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
-
- if include_cond:
+ torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
+ latent_sample = None
+
+ with devices.autocast():
+ latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
+
+ if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+ latent_sampling_method = "once"
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ elif latent_sampling_method == "deterministic":
+ # Works only for DiagonalGaussianDistribution
+ latent_dist.std = 0
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ elif latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+
+ if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
- self.dataset.append(entry)
-
- assert len(self.dataset) > 1, "No images have been found in the dataset."
- self.length = len(self.dataset) * repeats // batch_size
+ if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
+ with devices.autocast():
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
- self.initial_indexes = np.arange(len(self.dataset))
- self.indexes = None
- self.shuffle()
+ self.dataset.append(entry)
+ del torchdata
+ del latent_dist
+ del latent_sample
- def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.length = len(self.dataset)
+ assert self.length > 0, "No images have been found in the dataset."
+ self.batch_size = min(batch_size, self.length)
+ self.gradient_step = min(gradient_step, self.length // self.batch_size)
+ self.latent_sampling_method = latent_sampling_method
def create_text(self, filename_text):
text = random.choice(self.lines)
+ tags = filename_text.split(',')
+ if self.tag_drop_out != 0:
+ tags = [t for t in tags if random.random() > self.tag_drop_out]
+ if self.shuffle_tags:
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
return self.length
def __getitem__(self, i):
- res = []
-
- for j in range(self.batch_size):
- position = i * self.batch_size + j
- if position % len(self.indexes) == 0:
- self.shuffle()
-
- index = self.indexes[position % len(self.indexes)]
- entry = self.dataset[index]
-
- if entry.cond is None:
- entry.cond_text = self.create_text(entry.filename_text)
-
- res.append(entry)
-
- return res
+ entry = self.dataset[i]
+ if self.tag_drop_out != 0 or self.shuffle_tags:
+ entry.cond_text = self.create_text(entry.filename_text)
+ if self.latent_sampling_method == "random":
+ entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
+ return entry
+
+class PersonalizedDataLoader(DataLoader):
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+ super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ if latent_sampling_method == "random":
+ self.collate_fn = collate_wrapper_random
+ else:
+ self.collate_fn = collate_wrapper
+
+
+class BatchLoader:
+ def __init__(self, data):
+ self.cond_text = [entry.cond_text for entry in data]
+ self.cond = [entry.cond for entry in data]
+ self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ #self.emb_index = [entry.emb_index for entry in data]
+ #print(self.latent_sample.device)
+
+ def pin_memory(self):
+ self.latent_sample = self.latent_sample.pin_memory()
+ return self
+
+def collate_wrapper(batch):
+ return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+ def __init__(self, data):
+ super().__init__(data)
+
+ def pin_memory(self):
+ return self
+
+def collate_wrapper_random(batch):
+ return BatchLoaderRandom(batch) \ No newline at end of file
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 898ce3b3..ea653806 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
+from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos
image = srcimage.copy()
-
+ fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
- fontsize = 32
+
font = ImageFont.truetype(textfont, fontsize)
padding = 10
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,30 +4,37 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
"""
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+
def __iter__(self):
return self
@@ -52,7 +59,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
- if step_number <= self.end_step:
+ if step_number < self.end_step:
return
try:
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,27 +1,26 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
import time
-from modules import shared, images
+from modules import shared, images, deepbooru
+from modules.paths import models_path
from modules.shared import opts, cmd_opts
-if cmd_opts.deepdanbooru:
- import modules.deepbooru as deepbooru
+from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
if process_caption_deepbooru:
- db_opts = deepbooru.create_deepbooru_opts()
- db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
- deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+ deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally:
@@ -29,88 +28,169 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
- deepbooru.release_process()
+ deepbooru.model.stop()
+def listfiles(dirname):
+ return os.listdir(dirname)
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+
+class PreprocessParams:
+ src = None
+ dstdir = None
+ subindex = 0
+ flip = False
+ process_caption = False
+ process_caption_deepbooru = False
+ preprocess_txt_action = None
+
+
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
+ caption = ""
+
+ if params.process_caption:
+ caption += shared.interrogator.generate_caption(image)
+
+ if params.process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.model.tag_multi(image)
+
+ filename_part = params.src
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{params.subindex}-{filename_part}"
+ image.save(os.path.join(params.dstdir, f"{basename}.png"))
+
+ if params.preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif params.preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif params.preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
+
+ params.subindex += 1
+
+
+def save_pic(image, index, params, existing_caption=None):
+ save_pic_with_caption(image, index, params, existing_caption=existing_caption)
+
+ if params.flip:
+ save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
os.makedirs(dst, exist_ok=True)
- files = os.listdir(src)
+ files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index):
- caption = ""
-
- if process_caption:
- caption += shared.interrogator.generate_caption(image)
-
- if process_caption_deepbooru:
- if len(caption) > 0:
- caption += ", "
- caption += deepbooru.get_tags_from_process(image)
-
- filename_part = filename
- filename_part = os.path.splitext(filename_part)[0]
- filename_part = os.path.basename(filename_part)
-
- basename = f"{index:05}-{subindex[0]}-{filename_part}"
- image.save(os.path.join(dst, f"{basename}.png"))
-
- if len(caption) > 0:
- with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
- file.write(caption)
-
- subindex[0] += 1
-
- def save_pic(image, index):
- save_pic_with_caption(image, index)
-
- if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ params = PreprocessParams()
+ params.dstdir = dst
+ params.flip = process_flip
+ params.process_caption = process_caption
+ params.process_caption_deepbooru = process_caption_deepbooru
+ params.preprocess_txt_action = preprocess_txt_action
for index, imagefile in enumerate(tqdm.tqdm(files)):
- subindex = [0]
+ params.subindex = 0
filename = os.path.join(src, imagefile)
try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
- if shared.state.interrupted:
- break
-
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
-
- if process_split and is_tall:
- img = img.resize((width, height * img.height // img.width))
+ params.src = filename
- top = img.crop((0, 0, width, height))
- save_pic(top, index)
+ existing_caption = None
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
- bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((width * img.width // img.height, height))
-
- left = img.crop((0, 0, width, height))
- save_pic(left, index)
+ if shared.state.interrupted:
+ break
- right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
+
+ process_default_resize = True
+
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+ save_pic(splitted, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_focal_crop and img.height != img.width:
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
+ )
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_default_resize:
img = images.resize_image(1, img, width, height)
- save_pic(img, index)
+ save_pic(img, index, params, existing_caption=existing_caption)
shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..2250e41b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -23,9 +23,12 @@ class Embedding:
self.vec = vec
self.name = name
self.step = step
+ self.shape = None
+ self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -39,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -57,14 +67,17 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
+ self.expected_shape = -1
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
- ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+ # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
+ ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
@@ -74,21 +87,26 @@ class EmbeddingDatabase:
return embedding
- def load_textual_inversion_embeddings(self):
+ def get_expected_shape(self):
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
+
+ def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
+ if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- data = []
-
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -96,8 +114,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -119,9 +139,15 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- self.register_embedding(embedding, shared.sd_model)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
@@ -132,12 +158,13 @@ class EmbeddingDatabase:
process_file(fullfn, fn)
except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
- print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -153,19 +180,23 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
- embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
- ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+
+ embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -180,7 +211,6 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if step % shared.opts.training_write_csv_every != 0:
return
-
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -189,26 +219,52 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header:
csv_writer.writeheader()
- epoch = step // epoch_len
- epoch_step = step - epoch * epoch_len
+ epoch = (step - 1) // epoch_len
+ epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
- "step": step + 1,
- "epoch": epoch + 1,
- "epoch_step": epoch_step + 1,
+ "step": step,
+ "epoch": epoch,
+ "epoch_step": epoch_step,
**values,
})
-
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert embedding_name, 'embedding not selected'
-
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
+ assert gradient_step > 0, "Gradient accumulation step must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_file, "Prompt template file is empty"
+ assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0 , "Max steps must be positive"
+ assert isinstance(save_model_every, int), "Save {name} must be integer"
+ assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0 , "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
+
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+ unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
@@ -227,148 +283,241 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
-
- cond_model = shared.sd_model.cond_stage_model
-
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
- embedding.vec.requires_grad = True
-
- losses = torch.zeros((32,))
-
- last_saved_file = "<none>"
- last_saved_image = "<none>"
- embedding_yet_to_be_embedded = False
+ checkpoint = sd_models.select_checkpoint()
- ititial_step = embedding.step or 0
- if ititial_step > steps:
+ initial_step = embedding.step or 0
+ if initial_step >= steps:
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entries in pbar:
- embedding.step = i + ititial_step
-
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step - (epoch_num * len(ds)) + 1
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0]
-
- shared.state.current_image = image
-
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
-
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
-
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
-
- title = "<{}>".format(data.get('name', '???'))
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ pin_memory = shared.opts.pin_memory
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ latent_sampling_method = ds.latent_sampling_method
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
- image.save(last_saved_image)
+ if unload:
+ shared.parallel_processing_allowed = False
+ shared.sd_model.first_stage_model.to(devices.cpu)
- last_saved_image += f", prompt: {preview_text}"
+ embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+ # n steps = batch_size * gradient_step * n image processed
+ steps_per_epoch = len(ds) // batch_size // gradient_step
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+ _loss_step = 0 #internal
- shared.state.job_no = embedding.step
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+ forced_filename = "<none>"
+ embedding_yet_to_be_embedded = False
- shared.state.textinfo = f"""
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ for i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
+
+ loss = shared.sd_model(x, cond)[0] / gradient_step
+ del x
+
+ _loss_step += loss.item()
+ scaler.scale(loss).backward()
+
+ # go back until we reach gradient accumulation steps
+ if (j + 1) % gradient_step != 0:
+ continue
+ scaler.step(optimizer)
+ scaler.update()
+ embedding.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // steps_per_epoch
+ epoch_step = embedding.step % steps_per_epoch
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
<p>
-Loss: {losses.mean():.7f}<br/>
-Step: {embedding.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
-
- checkpoint = sd_models.select_checkpoint()
-
- embedding.sd_checkpoint = checkpoint.hash
- embedding.sd_checkpoint_name = checkpoint.model_name
- embedding.cached_checksum = None
- embedding.save(filename)
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr)
+ pass
+ finally:
+ pbar.leave = False
+ pbar.close()
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename
+
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 36881e7a..35c4feef 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -18,15 +18,17 @@ def create_embedding(name, initialization_text, nvpt):
def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args)
- return "Preprocessing finished.", ""
+ return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+ apply_optimizations = shared.opts.training_xattention_optimizations
try:
- sd_hijack.undo_optimizations()
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception:
raise
finally:
- sd_hijack.apply_optimizations()
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2381347f..e189a899 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,12 +1,14 @@
import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules import sd_samplers
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+ StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -20,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_index=sampler_index,
+ sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
@@ -31,10 +33,13 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling,
enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None,
- firstphase_width=firstphase_width if enable_hr else None,
- firstphase_height=firstphase_height if enable_hr else None,
+ hr_scale=hr_scale,
+ hr_upscaler=hr_upscaler,
)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
@@ -43,6 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None:
processed = process_images(p)
+ p.close()
+
shared.total_tqdm.clear()
generation_info_js = processed.js()
@@ -52,5 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
-
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
index 0abd177a..e4859020 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,62 +1,63 @@
-import base64
import html
-import io
import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
-import numpy as np
-import torch
-from PIL import Image, PngImagePlugin
-import piexif
-
import gradio as gr
-import gradio.utils
import gradio.routes
+import gradio.utils
+import numpy as np
+from PIL import Image, PngImagePlugin
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules.ui_components import FormRow, FormGroup, ToolButton
from modules.paths import script_path
+
from modules.shared import opts, cmd_opts, restricted_opts
-if cmd_opts.deepdanbooru:
- from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
-import modules.ldsr_model
-import modules.scripts
-import modules.gfpgan_model
+
import modules.codeformer_model
+import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.scripts
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.generation_parameters_copypaste import image_from_url_text
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
-
if not cmd_opts.share and not cmd_opts.listen:
# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
-if cmd_opts.ngrok != None:
+if cmd_opts.ngrok is not None:
import modules.ngrok as ngrok
print('ngrok authtoken detected, trying to connect...')
- ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
+ ngrok.connect(
+ cmd_opts.ngrok,
+ cmd_opts.port if cmd_opts.port is not None else 7860,
+ cmd_opts.ngrok_region
+ )
def gr_show(visible=True):
@@ -69,57 +70,34 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
+.wrap .z-20 svg { display:none!important; }
+.wrap .z-20::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
+.meta-text-center { display:none!important; }
"""
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
-art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
+clear_prompt_symbol = '\U0001F5D1' # 🗑️
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
return text
-
-def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
- filename = filedata["name"]
- tempdir = os.path.normpath(tempfile.gettempdir())
- normfn = os.path.normpath(filename)
- assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
-
- return Image.open(filename)
-
- if type(filedata) == list:
- if len(filedata) == 0:
- return None
-
- filedata = filedata[0]
-
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- filedata = base64.decodebytes(filedata.encode('utf-8'))
- image = Image.open(io.BytesIO(filedata))
- return image
-
-
def send_gradio_gallery_to_image(x):
if len(x) == 0:
return None
-
return image_from_url_text(x[0])
-
def save_files(js_data, images, do_make_zip, index):
import csv
filenames = []
@@ -168,7 +146,7 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip
if do_make_zip:
@@ -181,85 +159,10 @@ def save_files(js_data, images, do_make_zip, index):
zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath)
- return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-
-
-def save_pil_to_file(pil_image, dir=None):
- use_metadata = False
- metadata = PngImagePlugin.PngInfo()
- for key, value in pil_image.info.items():
- if isinstance(key, str) and isinstance(value, str):
- metadata.add_text(key, value)
- use_metadata = True
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
- file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
- pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
-# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
-
-
-def wrap_gradio_call(func, extra_outputs=None):
- def f(*args, extra_outputs_array=extra_outputs, **kwargs):
- run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
- if run_memmon:
- shared.mem_mon.monitor()
- t = time.perf_counter()
-
- try:
- res = list(func(*args, **kwargs))
- except Exception as e:
- # When printing out our debug argument list, do not print out more than a MB of text
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {str(args)} {str(kwargs)}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
-
- shared.state.job = ""
- shared.state.job_count = 0
-
- if extra_outputs_array is None:
- extra_outputs_array = [None, '']
-
- res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
-
- elapsed = time.perf_counter() - t
- elapsed_m = int(elapsed // 60)
- elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_s:.2f}s"
- if (elapsed_m > 0):
- elapsed_text = f"{elapsed_m}m "+elapsed_text
-
- if run_memmon:
- mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
- active_peak = mem_stats['active_peak']
- reserved_peak = mem_stats['reserved_peak']
- sys_peak = mem_stats['system_peak']
- sys_total = mem_stats['total']
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
-
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
- else:
- vram_html = ''
-
- # last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
-
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.job_count = 0
-
- return tuple(res)
-
- return f
-
def calc_time_left(progress, threshold, label, force_display, showTime):
if progress == 0:
@@ -306,13 +209,8 @@ def check_progress_call(id_part):
image = gr_show(False)
preview_visibility = gr_show(False)
- if opts.show_progress_every_n_steps > 0:
- if shared.parallel_processing_allowed:
-
- if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
- shared.state.current_image_sampling_step = shared.state.sampling_step
-
+ if opts.show_progress_every_n_steps != 0:
+ shared.state.set_current_image()
image = shared.state.current_image
if image is None:
@@ -339,13 +237,6 @@ def check_progress_call_initial(id_part):
return check_progress_call(id_part)
-def roll_artist(prompt):
- allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
- artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
-
- return prompt + ", " + artist.name if prompt != '' else artist.name
-
-
def visit(x, func, path=""):
if hasattr(x, 'children'):
for c in x.children:
@@ -375,45 +266,41 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image):
- prompt = shared.interrogator.interrogate(image)
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image):
- prompt = get_deepbooru_tags(image)
+ prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt
-def create_seed_inputs():
- with gr.Row():
- with gr.Box():
- with gr.Row(elem_id='seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id='random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
+def create_seed_inputs(target_interface):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
- with gr.Box(elem_id='subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
- with gr.Row(visible=False) as seed_extra_row_1:
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
- with gr.Box():
- with gr.Row(elem_id='subseed_row'):
- subseed = gr.Number(label='Variation seed', value=-1)
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id='random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
-
- with gr.Row(visible=False) as seed_extra_row_2:
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+
+ with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -426,6 +313,17 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+def connect_clear_prompt(button):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+ button.click(
+ _js="clear_prompt",
+ fn=None,
+ inputs=[],
+ outputs=[],
+ )
+
+
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -497,22 +395,26 @@ def create_toprow(is_img2img):
)
with gr.Column(scale=1, elem_id="roll_col"):
- roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
-
+ clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
button_interrogate = None
button_deepbooru = None
if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
-
- if cmd_opts.deepdanbooru:
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
with gr.Row():
@@ -541,7 +443,7 @@ def create_toprow(is_img2img):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -569,6 +471,9 @@ def apply_setting(key, value):
if value is None:
return gr.update()
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
# dont allow model to be swapped when model hash exists in prompt
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
return gr.update()
@@ -586,7 +491,7 @@ def apply_setting(key, value):
return
valtype = type(opts.data_labels[key].default)
- oldval = opts.data[key]
+ oldval = opts.data.get(key, None)
opts.data[key] = valtype(value) if valtype != type(None) else value
if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
@@ -595,102 +500,240 @@ def apply_setting(key, value):
return value
-def create_ui(wrap_gradio_gpu_call):
- import modules.img2img
- import modules.txt2img
+def update_generation_info(args):
+ generation_info, html_info, img_index = args
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info
+ return plaintext_to_html(generation_info["infotexts"][img_index])
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info
- def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
- for k, v in args.items():
- setattr(refresh_component, k, v)
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
- return gr.update(**(args or {}))
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn = refresh,
- inputs = [],
- outputs = [refresh_component]
- )
- return refresh_button
+ return gr.update(**(args or {}))
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
- dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
+def create_output_panel(tabname, outdir):
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ else:
+ sp.Popen(["xdg-open", path])
- with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ with gr.Column(variant='panel'):
+ with gr.Group():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
- with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
- tiling = gr.Checkbox(label='Tiling', value=False)
- enable_hr = gr.Checkbox(label='Highres. fix', value=False)
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')
- with gr.Row(visible=False) as hr_options:
- firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
- firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
- with gr.Row(equal_height=True):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
+ open_folder_button.click(
+ fn=lambda: open_folder(opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ html_log = gr.HTML()
- with gr.Column(variant='panel'):
+ generation_info = gr.Textbox(visible=False)
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="(x, y) => [x, y, selected_gallery_index()]",
+ inputs=[generation_info, html_info],
+ outputs=[html_info],
+ preprocess=False
+ )
+
+ save.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
- with gr.Group():
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
+ save_zip.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
- with gr.Column():
- with gr.Row():
- save = gr.Button('Save')
- send_to_img2img = gr.Button('Send to img2img')
- send_to_inpaint = gr.Button('Send to inpaint')
- send_to_extras = gr.Button('Send to extras')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
+ else:
+ html_info_x = gr.HTML()
+ html_info = gr.HTML()
+ html_log = gr.HTML()
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_index.save_to_config = True
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+ else:
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
+def ordered_ui_categories():
+ user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ yield category
+
+
+def create_ui():
+ import modules.img2img
+ import modules.txt2img
+
+ reload_javascript()
+
+ parameters_copypaste.reset()
+
+ modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+ modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+ txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+
+ dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Row(elem_id='txt2img_progress_row'):
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="txt2img_progressbar")
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
+ setup_progressbar(progressbar, txt2img_preview, 'txt2img')
+
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+
+ elif category == "hires_fix":
+ with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
txt2img_prompt,
@@ -710,13 +753,15 @@ def create_ui(wrap_gradio_gpu_call):
width,
enable_hr,
denoising_strength,
- firstphase_width,
- firstphase_height,
+ hr_scale,
+ hr_upscaler,
] + custom_inputs,
+
outputs=[
txt2img_gallery,
generation_info,
- html_info
+ html_info,
+ html_log,
],
show_progress=False,
)
@@ -741,34 +786,6 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[hr_options],
)
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
- inputs=[
- generation_info,
- txt2img_gallery,
- do_make_zip,
- html_info,
- ],
- outputs=[
- download_files,
- html_info,
- html_info,
- html_info,
- ]
- )
-
- roll.click(
- fn=roll_artist,
- _js="update_txt2img_tokens",
- inputs=[
- txt2img_prompt,
- ],
- outputs=[
- txt2img_prompt,
- ]
- )
-
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
@@ -787,9 +804,11 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (firstphase_width, "First pass size-1"),
- (firstphase_height, "First pass size-2"),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
txt2img_preview_params = [
txt2img_prompt,
@@ -802,10 +821,13 @@ def create_ui(wrap_gradio_gpu_call):
height,
]
- token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
+ token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+
+ modules.scripts.scripts_current = modules.scripts.scripts_img2img
+ modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -818,88 +840,98 @@ def create_ui(wrap_gradio_gpu_call):
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
setup_progressbar(progressbar, img2img_preview, 'img2img')
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+ init_img_with_mask_orig = gr.State(None)
+
+ use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
+ if use_color_sketch:
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
- with gr.TabItem('Inpaint', id='inpaint'):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
- with gr.Row():
- mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
+ with FormRow():
+ mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
- with gr.Row():
- inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
- inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
- with gr.TabItem('Batch img2img', id='batch'):
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- with gr.Row():
- resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
- steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
- with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
- with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
- tiling = gr.Checkbox(label='Tiling', value=False)
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
- with gr.Row():
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
+ elif category == "cfg":
+ with FormGroup():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
- with gr.Group():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+ elif category == "checkboxes":
+ with FormRow(elem_id="img2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
- with gr.Group():
- custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
- with gr.Column(variant='panel'):
+ elif category == "scripts":
+ with FormGroup(elem_id="img2img_script_container"):
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui()
- with gr.Group():
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
-
- with gr.Column():
- with gr.Row():
- save = gr.Button('Save')
- img2img_send_to_img2img = gr.Button('Send to img2img')
- img2img_send_to_inpaint = gr.Button('Send to inpaint')
- img2img_send_to_extras = gr.Button('Send to extras')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
+ parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -930,7 +962,7 @@ def create_ui(wrap_gradio_gpu_call):
)
img2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.img2img.img2img),
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -940,12 +972,14 @@ def create_ui(wrap_gradio_gpu_call):
img2img_prompt_style2,
init_img,
init_img_with_mask,
+ init_img_with_mask_orig,
init_img_inpaint,
init_mask_inpaint,
mask_mode,
steps,
sampler_index,
mask_blur,
+ mask_alpha,
inpainting_fill,
restore_faces,
tiling,
@@ -967,7 +1001,8 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[
img2img_gallery,
generation_info,
- html_info
+ html_info,
+ html_log,
],
show_progress=False,
)
@@ -981,39 +1016,10 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt],
)
- if cmd_opts.deepdanbooru:
- img2img_deepbooru.click(
- fn=interrogate_deepbooru,
- inputs=[init_img],
- outputs=[img2img_prompt],
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
- inputs=[
- generation_info,
- img2img_gallery,
- do_make_zip,
- html_info,
- ],
- outputs=[
- download_files,
- html_info,
- html_info,
- html_info,
- ]
- )
-
- roll.click(
- fn=roll_artist,
- _js="update_img2img_tokens",
- inputs=[
- img2img_prompt,
- ],
- outputs=[
- img2img_prompt,
- ]
+ img2img_deepbooru.click(
+ fn=interrogate_deepbooru,
+ inputs=[init_img],
+ outputs=[img2img_prompt],
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
@@ -1038,6 +1044,8 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[prompt, negative_prompt, style1, style2],
)
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1054,66 +1062,62 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ (mask_blur, "Mask blur"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+
+ modules.scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image'):
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+ with gr.TabItem('Single Image', elem_id="extras_single_tab"):
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
- with gr.TabItem('Batch Process'):
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
- with gr.TabItem('Batch from Directory'):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
- placeholder="A directory on the same machine where the server is running."
- )
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
- placeholder="Leave blank to save images to the default path."
- )
- show_extras_results = gr.Checkbox(label='Show result images', value=True)
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by'):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
- with gr.TabItem('Scale to'):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
with gr.Group():
with gr.Row():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
with gr.Group():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
with gr.Group():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
- with gr.Column(variant='panel'):
- result_images = gr.Gallery(label="Result", show_label=False)
- html_info_x = gr.HTML()
- html_info = gr.HTML()
- extras_send_to_img2img = gr.Button('Send to img2img')
- extras_send_to_inpaint = gr.Button('Send to inpaint')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
- open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+ with gr.Group():
+ upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
+ result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
- fn=wrap_gradio_gpu_call(modules.extras.run_extras),
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
_js="get_extras_tab_index",
inputs=[
dummy_component,
@@ -1133,6 +1137,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_upscaler_1,
extras_upscaler_2,
extras_upscaler_2_visibility,
+ upscale_before_face_fix,
],
outputs=[
result_images,
@@ -1140,19 +1145,11 @@ def create_ui(wrap_gradio_gpu_call):
html_info,
]
)
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
- extras_send_to_img2img.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
- inputs=[result_images],
- outputs=[init_img],
- )
-
- extras_send_to_inpaint.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_inpaint",
- inputs=[result_images],
- outputs=[init_img_with_mask],
+ extras_image.change(
+ fn=modules.extras.clear_cache,
+ inputs=[], outputs=[]
)
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
@@ -1162,48 +1159,47 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column(variant='panel'):
html = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
html2 = gr.HTML()
-
with gr.Row():
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+ parameters_copypaste.bind_buttons(buttons, image, generation_info)
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
- #images history
- images_history_switch_dict = {
- "fn":modules.generation_parameters_copypaste.connect_paste,
- "t2i":txt2img_paste_fields,
- "i2i":img2img_paste_fields
- }
-
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
- with gr.Blocks() as modelmerger_interface:
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
- custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
- save_as_half = gr.Checkbox(value=False, label="Save as float16")
+ create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+ custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+ interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+
+ with gr.Row():
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
-
- with gr.Blocks() as train_interface:
+ with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@@ -1211,74 +1207,118 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
- new_embedding_name = gr.Textbox(label="Name")
- initialization_text = gr.Textbox(label="Initialization text", value="*")
- nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create embedding", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
with gr.Tab(label="Create hypernetwork"):
- new_hypernetwork_name = gr.Textbox(label="Name")
- new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
- new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
with gr.Tab(label="Preprocess images"):
- process_src = gr.Textbox(label='Source directory')
- process_dst = gr.Textbox(label='Destination directory')
- process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
+ process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
- process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
- process_caption = gr.Checkbox(label='Use BLIP for caption')
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+ process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
+ process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
+
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
+
+ with gr.Row(visible=False) as process_focal_crop_row:
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ with gr.Row():
+ interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
+ run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
+
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
+ process_focal_crop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_focal_crop],
+ outputs=[process_focal_crop_row],
+ )
with gr.Tab(label="Train"):
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
- batch_size = gr.Number(label='Batch size', value=1, precision=0)
- dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
- log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
- training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
- steps = gr.Number(label='Max steps', value=100000, precision=0)
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
- save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
- preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
+ with gr.Row():
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+ with gr.Row():
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
- train_embedding = gr.Button(value="Train Embedding", variant='primary')
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
+
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+ script_callbacks.ui_train_tabs_callback(params)
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1296,6 +1336,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1309,8 +1350,12 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
+ new_hypernetwork_activation_func,
+ new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
],
outputs=[
train_hypernetwork_name,
@@ -1327,10 +1372,18 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
- process_caption_deepbooru
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
+ process_focal_crop,
+ process_focal_crop_face_weight,
+ process_focal_crop_entropy_weight,
+ process_focal_crop_edges_weight,
+ process_focal_crop_debug,
],
outputs=[
ti_output,
@@ -1343,13 +1396,17 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
batch_size,
+ gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
@@ -1368,13 +1425,17 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
- learn_rate,
+ hypernetwork_learn_rate,
batch_size,
+ gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
@@ -1393,6 +1454,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
+ interrupt_preprocessing.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1417,142 +1484,108 @@ def create_ui(wrap_gradio_gpu_call):
if info.refresh is not None:
if is_quicksettings:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- with gr.Row(variant="compact"):
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
-
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
return res
components = []
component_dict = {}
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- else:
- sp.Popen(["xdg-open", path])
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
def run_settings(*args):
- changed = 0
+ changed = []
for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
- return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
+ assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- continue
+ if opts.set(key, value):
+ changed.append(key)
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- continue
-
- oldval = opts.data.get(key, None)
- opts.data[key] = value
-
- if oldval != value:
- if opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
-
- changed += 1
-
- opts.save(shared.config_filename)
-
- return f'{changed} settings changed.', opts.dumpjson()
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- return gr.update(value=oldval), opts.dumpjson()
-
- oldval = opts.data.get(key, None)
- opts.data[key] = value
-
- if oldval != value:
- if opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
opts.save(shared.config_filename)
return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
- settings_submit = gr.Button(value="Apply settings", variant='primary')
- result = gr.HTML()
+ with gr.Row():
+ with gr.Column(scale=6):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
- settings_cols = 3
- items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ result = gr.HTML(elem_id="settings_result")
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
- quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+ quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = []
- cols_displayed = 0
- items_displayed = 0
previous_section = None
- column = None
- with gr.Row(elem_id="settings").style(equal_height=False):
+ current_tab = None
+ with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
- if previous_section != item.section:
- if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
- if column is not None:
- column.__exit__()
+ if previous_section != item.section and not section_must_be_skipped:
+ elem_id, text = item.section
- column = gr.Column(variant='panel')
- column.__enter__()
+ if current_tab is not None:
+ current_tab.__exit__()
- items_displayed = 0
- cols_displayed += 1
+ current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab.__enter__()
previous_section = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
-
- if k in quicksettings_names:
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
else:
component = create_setting_component(k)
component_dict[k] = component
components.append(component)
- items_displayed += 1
- with gr.Row():
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ if current_tab is not None:
+ current_tab.__exit__()
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+ with gr.TabItem("Actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+
+ if os.path.exists("html/licenses.html"):
+ with open("html/licenses.html", encoding="utf8") as file:
+ with gr.TabItem("Licenses"):
+ gr.HTML(file.read(), elem_id="licenses")
+
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
request_notifications.click(
fn=lambda: None,
@@ -1570,58 +1603,64 @@ Requested path was: {f}
def reload_scripts():
modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
- outputs=[],
- _js='function(){}'
+ outputs=[]
)
def request_restart():
shared.state.interrupt()
- settings_interface.gradio_ref.do_restart = True
+ shared.state.need_restart = True
restart_gradio.click(
fn=request_restart,
+ _js='restart_reload',
inputs=[],
outputs=[],
- _js='function(){restart_reload()}'
)
- if column is not None:
- column.__exit__()
-
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
- (settings_interface, "Settings", "settings"),
]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
- for i, k, item in quicksettings_list:
+ for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
- settings_interface.gradio_ref = demo
+ parameters_copypaste.integrate_settings_paste_fields(component_dict)
+ parameters_copypaste.run_bind()
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
@@ -1631,11 +1670,15 @@ Requested path was: {f}
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ if os.path.exists("html/footer.html"):
+ with open("html/footer.html", encoding="utf8") as file:
+ gr.HTML(file.read(), elem_id="footer")
+
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
- fn=run_settings,
+ fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
- outputs=[result, text_settings],
+ outputs=[text_settings, result],
)
for i, k, item in quicksettings_list:
@@ -1647,6 +1690,17 @@ Requested path was: {f}
outputs=[component, text_settings],
)
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [getattr(opts, key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys],
+ )
+
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
@@ -1654,7 +1708,7 @@ Requested path was: {f}
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+ return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
return results
modelmerger_merge.click(
@@ -1667,6 +1721,7 @@ Requested path was: {f}
interp_amount,
save_as_half,
custom_name,
+ checkpoint_format,
],
outputs=[
submit_result,
@@ -1676,85 +1731,6 @@ Requested path was: {f}
component_dict['sd_model_checkpoint'],
]
)
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
- txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
- img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
- send_to_img2img.click(
- fn=lambda img, *args: (image_from_url_text(img),*args),
- _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
- inputs=[txt2img_gallery] + txt2img_fields,
- outputs=[init_img] + img2img_fields,
- )
-
- send_to_inpaint.click(
- fn=lambda x, *args: (image_from_url_text(x), *args),
- _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
- inputs=[txt2img_gallery] + txt2img_fields,
- outputs=[init_img_with_mask] + img2img_fields,
- )
-
- img2img_send_to_img2img.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
- inputs=[img2img_gallery],
- outputs=[init_img],
- )
-
- img2img_send_to_inpaint.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_inpaint",
- inputs=[img2img_gallery],
- outputs=[init_img_with_mask],
- )
-
- send_to_extras.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_extras",
- inputs=[txt2img_gallery],
- outputs=[extras_image],
- )
-
- open_txt2img_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
- inputs=[],
- outputs=[],
- )
-
- open_img2img_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
- inputs=[],
- outputs=[],
- )
-
- open_extras_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
- inputs=[],
- outputs=[],
- )
-
- img2img_send_to_extras.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_extras",
- inputs=[img2img_gallery],
- outputs=[extras_image],
- )
-
- settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'CLIP_stop_at_last_layers': 'Clip skip',
- 'sd_model_checkpoint': 'Model hash',
- }
-
- settings_paste_fields = [
- (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
- for k, v in settings_map.items()
- ]
-
- modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
- modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
-
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
@@ -1774,7 +1750,7 @@ Requested path was: {f}
def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field
- if getattr(obj,'custom_script_source',None) is not None:
+ if getattr(obj, 'custom_script_source', None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key
if getattr(obj, 'do_not_save_to_config', False):
@@ -1829,13 +1805,14 @@ Requested path was: {f}
return demo
-def load_javascript(raw_response):
+def reload_javascript():
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
- jsdir = os.path.join(script_path, "javascript")
- for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None:
@@ -1844,7 +1821,7 @@ def load_javascript(raw_response):
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
def template_response(*args, **kwargs):
- res = raw_response(*args, **kwargs)
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers()
@@ -1853,6 +1830,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript,
- gradio.routes.templates.TemplateResponse)
-reload_javascript()
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
diff --git a/modules/ui_components.py b/modules/ui_components.py
new file mode 100644
index 00000000..91eb0e3d
--- /dev/null
+++ b/modules/ui_components.py
@@ -0,0 +1,25 @@
+import gradio as gr
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+class FormRow(gr.Row, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "row"
+
+
+class FormGroup(gr.Group, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "group"
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
new file mode 100644
index 00000000..eec9586f
--- /dev/null
+++ b/modules/ui_extensions.py
@@ -0,0 +1,328 @@
+import json
+import os.path
+import shutil
+import sys
+import time
+import traceback
+
+import git
+
+import gradio as gr
+import html
+import shutil
+import errno
+
+from modules import extensions, shared, paths
+
+
+available_extensions = {"extensions": []}
+
+
+def check_access():
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
+
+
+def apply_and_restart(disable_list, update_list):
+ check_access()
+
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ update = json.loads(update_list)
+ assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+
+ update = set(update)
+
+ for ext in extensions.extensions:
+ if ext.name not in update:
+ continue
+
+ try:
+ ext.fetch_and_reset_hard()
+ except Exception:
+ print(f"Error getting updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+
+def check_updates():
+ check_access()
+
+ for ext in extensions.extensions:
+ if ext.remote is None:
+ continue
+
+ try:
+ ext.check_updates()
+ except Exception:
+ print(f"Error checking updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return extension_table()
+
+
+def extension_table():
+ code = f"""<!-- {time.time()} -->
+ <table id="extensions">
+ <thead>
+ <tr>
+ <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>URL</th>
+ <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+ for ext in extensions.extensions:
+ remote = ""
+ if ext.is_builtin:
+ remote = "built-in"
+ elif ext.remote:
+ remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
+
+ if ext.can_update:
+ ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
+ else:
+ ext_status = ext.status
+
+ code += f"""
+ <tr>
+ <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td>{remote}</td>
+ <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
+ </tr>
+ """
+
+ code += """
+ </tbody>
+ </table>
+ """
+
+ return code
+
+
+def normalize_git_url(url):
+ if url is None:
+ return ""
+
+ url = url.replace(".git", "")
+ return url
+
+
+def install_extension_from_url(dirname, url):
+ check_access()
+
+ assert url, 'No URL specified'
+
+ if dirname is None or dirname == "":
+ *parts, last_part = url.split('/')
+ last_part = normalize_git_url(last_part)
+
+ dirname = last_part
+
+ target_dir = os.path.join(extensions.extensions_dir, dirname)
+ assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
+
+ normalized_url = normalize_git_url(url)
+ assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
+
+ tmpdir = os.path.join(paths.script_path, "tmp", dirname)
+
+ try:
+ shutil.rmtree(tmpdir, True)
+
+ repo = git.Repo.clone_from(url, tmpdir)
+ repo.remote().fetch()
+
+ try:
+ os.rename(tmpdir, target_dir)
+ except OSError as err:
+ # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
+ # Shouldn't cause any new issues at least but we probably want to handle it there too.
+ if err.errno == errno.EXDEV:
+ # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
+ # Since we can't use a rename, do the slower but more versitile shutil.move()
+ shutil.move(tmpdir, target_dir)
+ else:
+ # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
+ raise(err)
+
+ import launch
+ launch.run_extension_installer(target_dir)
+
+ extensions.list_extensions()
+ return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
+ finally:
+ shutil.rmtree(tmpdir, True)
+
+
+def install_extension_from_index(url, hide_tags):
+ ext_table, message = install_extension_from_url(None, url)
+
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+
+ return code, ext_table, message
+
+
+def refresh_available_extensions(url, hide_tags):
+ global available_extensions
+
+ import urllib.request
+ with urllib.request.urlopen(url) as response:
+ text = response.read()
+
+ available_extensions = json.loads(text)
+
+ code, tags = refresh_available_extensions_from_data(hide_tags)
+
+ return url, code, gr.CheckboxGroup.update(choices=tags), ''
+
+
+def refresh_available_extensions_for_tags(hide_tags):
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+
+ return code, ''
+
+
+def refresh_available_extensions_from_data(hide_tags):
+ extlist = available_extensions["extensions"]
+ installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+
+ tags = available_extensions.get("tags", {})
+ tags_to_hide = set(hide_tags)
+ hidden = 0
+
+ code = f"""<!-- {time.time()} -->
+ <table id="available_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>Description</th>
+ <th>Action</th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+ for ext in extlist:
+ name = ext.get("name", "noname")
+ url = ext.get("url", None)
+ description = ext.get("description", "")
+ extension_tags = ext.get("tags", [])
+
+ if url is None:
+ continue
+
+ existing = installed_extension_urls.get(normalize_git_url(url), None)
+ extension_tags = extension_tags + ["installed"] if existing else extension_tags
+
+ if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ hidden += 1
+ continue
+
+ install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+
+ tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
+
+ code += f"""
+ <tr>
+ <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
+ <td>{html.escape(description)}</td>
+ <td>{install_code}</td>
+ </tr>
+
+ """
+
+ for tag in [x for x in extension_tags if x not in tags]:
+ tags[tag] = tag
+
+ code += """
+ </tbody>
+ </table>
+ """
+
+ if hidden > 0:
+ code += f"<p>Extension hidden: {hidden}</p>"
+
+ return code, list(tags)
+
+
+def create_ui():
+ import modules.ui
+
+ with gr.Blocks(analytics_enabled=False) as ui:
+ with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.TabItem("Installed"):
+
+ with gr.Row():
+ apply = gr.Button(value="Apply and restart UI", variant="primary")
+ check = gr.Button(value="Check for updates")
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+
+ extensions_table = gr.HTML(lambda: extension_table())
+
+ apply.click(
+ fn=apply_and_restart,
+ _js="extensions_apply",
+ inputs=[extensions_disabled_list, extensions_update_list],
+ outputs=[],
+ )
+
+ check.click(
+ fn=check_updates,
+ _js="extensions_check",
+ inputs=[],
+ outputs=[extensions_table],
+ )
+
+ with gr.TabItem("Available"):
+ with gr.Row():
+ refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
+ install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+
+ with gr.Row():
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
+
+ install_result = gr.HTML()
+ available_extensions_table = gr.HTML()
+
+ refresh_available_extensions_button.click(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ inputs=[available_extensions_index, hide_tags],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
+ )
+
+ install_extension_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[extension_to_install, hide_tags],
+ outputs=[available_extensions_table, extensions_table, install_result],
+ )
+
+ hide_tags.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags],
+ outputs=[available_extensions_table, install_result]
+ )
+
+ with gr.TabItem("Install from URL"):
+ install_url = gr.Text(label="URL for extension's git repository")
+ install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
+ install_button = gr.Button(value="Install", variant="primary")
+ install_result = gr.HTML(elem_id="extension_install_result")
+
+ install_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ inputs=[install_dirname, install_url],
+ outputs=[extensions_table, install_result],
+ )
+
+ return ui
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
new file mode 100644
index 00000000..21945235
--- /dev/null
+++ b/modules/ui_tempdir.py
@@ -0,0 +1,82 @@
+import os
+import tempfile
+from collections import namedtuple
+from pathlib import Path
+
+import gradio as gr
+
+from PIL import PngImagePlugin
+
+from modules import shared
+
+
+Savedfile = namedtuple("Savedfile", ["name"])
+
+
+def register_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
+
+
+def check_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'):
+ return any([filename in fileset for fileset in gradio.temp_file_sets])
+
+ if hasattr(gradio, 'temp_dirs'):
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
+
+ return False
+
+
+def save_pil_to_file(pil_image, dir=None):
+ already_saved_as = getattr(pil_image, 'already_saved_as', None)
+ if already_saved_as and os.path.isfile(already_saved_as):
+ register_tmp_file(shared.demo, already_saved_as)
+
+ file_obj = Savedfile(already_saved_as)
+ return file_obj
+
+ if shared.opts.temp_dir != "":
+ dir = shared.opts.temp_dir
+
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in pil_image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+ return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
+def on_tmpdir_changed():
+ if shared.opts.temp_dir == "" or shared.demo is None:
+ return
+
+ os.makedirs(shared.opts.temp_dir, exist_ok=True)
+
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
+
+
+def cleanup_tmpdr():
+ temp_dir = shared.opts.temp_dir
+ if temp_dir == "" or not os.path.isdir(temp_dir):
+ return
+
+ for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for name in files:
+ _, extension = os.path.splitext(name)
+ if extension != ".png":
+ continue
+
+ filename = os.path.join(root, name)
+ os.remove(filename)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 6ab2fb40..231680cb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path
@@ -52,14 +53,22 @@ class Upscaler:
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
- def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = img.width * scale
- dest_h = img.height * scale
+ dest_w = int(img.width * scale)
+ dest_h = int(img.height * scale)
+
for i in range(3):
+ shape = (img.width, img.height)
+
+ img = self.do_upscale(img, selected_model)
+
+ if shape == (img.width, img.height):
+ break
+
if img.width >= dest_w and img.height >= dest_h:
break
- img = self.do_upscale(img, selected_model)
+
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
@@ -120,3 +129,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
+
+class UpscalerNearest(Upscaler):
+ scalers = []
+
+ def do_upscale(self, img, selected_model=None):
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
+
+ def load_model(self, _):
+ pass
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.name = "Nearest"
+ self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file
diff --git a/modules/xlmr.py b/modules/xlmr.py
new file mode 100644
index 00000000..beab3fdf
--- /dev/null
+++ b/modules/xlmr.py
@@ -0,0 +1,137 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob= 0.1
+ config.bos_token_id=0
+ config.eos_token_id=2
+ config.hidden_act='gelu'
+ config.hidden_dropout_prob=0.1
+ config.hidden_size=1024
+ config.initializer_range=0.02
+ config.intermediate_size=4096
+ config.layer_norm_eps=1e-05
+ config.max_position_embeddings=514
+
+ config.num_attention_heads=16
+ config.num_hidden_layers=24
+ config.output_past=True
+ config.pad_token_id=1
+ config.position_embedding_type= "absolute"
+
+ config.type_vocab_size= 1
+ config.use_cache=True
+ config.vocab_size= 250002
+ config.project_dim = 768
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+ self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ self.pooler = lambda x: x[:,0]
+ self.post_init()
+
+ def encode(self,c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+ text["attention_mask"] = torch.tensor(
+ text['attention_mask']).to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) :
+ r"""
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ # last module outputs
+ sequence_output = outputs[0]
+
+
+ # project every module
+ sequence_output_ln = self.pre_LN(sequence_output)
+
+ # pooler
+ pooler_output = self.pooler(sequence_output_ln)
+ pooler_output = self.transformation(pooler_output)
+ projection_state = self.transformation(outputs.last_hidden_state)
+
+ return {
+ 'pooler_output':pooler_output,
+ 'last_hidden_state':outputs.last_hidden_state,
+ 'hidden_states':outputs.hidden_states,
+ 'attentions':outputs.attentions,
+ 'projection_state':projection_state,
+ 'sequence_out': sequence_output
+ }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
+ config_class= RobertaSeriesConfig \ No newline at end of file