aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/cache.py5
-rw-r--r--modules/cmd_args.py1
-rw-r--r--modules/interrogate.py5
-rw-r--r--modules/lowvram.py18
-rw-r--r--modules/realesrgan_model.py1
-rw-r--r--modules/sd_models.py19
-rw-r--r--modules/sd_models_types.py31
-rw-r--r--modules/sd_unet.py2
-rw-r--r--modules/sd_vae.py4
-rw-r--r--modules/shared.py7
10 files changed, 71 insertions, 22 deletions
diff --git a/modules/cache.py b/modules/cache.py
index a7cd3aeb..ff26a213 100644
--- a/modules/cache.py
+++ b/modules/cache.py
@@ -30,9 +30,12 @@ def dump_cache():
time.sleep(1)
with cache_lock:
- with open(cache_filename, "w", encoding="utf8") as file:
+ cache_filename_tmp = cache_filename + "-"
+ with open(cache_filename_tmp, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
+ os.replace(cache_filename_tmp, cache_filename)
+
dump_cache_after = None
dump_cache_thread = None
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 9f8e5b30..f0f361bd 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -35,6 +35,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
+parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
diff --git a/modules/interrogate.py b/modules/interrogate.py
index a3ae1dd5..3045560d 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -186,9 +186,8 @@ class InterrogateModels:
res = ""
shared.state.begin(job="interrogate")
try:
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- devices.torch_gc()
+ lowvram.send_everything_to_cpu()
+ devices.torch_gc()
self.load()
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 96f52b7b..45701046 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,5 +1,5 @@
import torch
-from modules import devices
+from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")
@@ -14,6 +14,20 @@ def send_everything_to_cpu():
module_in_gpu = None
+def is_needed(sd_model):
+ return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
+
+
+def apply(sd_model):
+ enable = is_needed(sd_model)
+ shared.parallel_processing_allowed = not enable
+
+ if enable:
+ setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
+ else:
+ sd_model.lowvram = False
+
+
def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
@@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
def is_enabled(sd_model):
- return getattr(sd_model, 'lowvram', False)
+ return sd_model.lowvram
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 0700b853..02841c30 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -55,6 +55,7 @@ class UpscalerRealESRGAN(Upscaler):
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
+ device=self.device,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 27d15e66..547e93c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -345,6 +345,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
+ devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
vae = model.first_stage_model
@@ -362,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
+ devices.dtype_unet = torch.float16
timer.record("apply half()")
- devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -517,7 +518,7 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
@@ -525,17 +526,17 @@ def send_model_to_cpu(m):
devices.torch_gc()
-def model_target_device():
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+def model_target_device(m):
+ if lowvram.is_needed(m):
return devices.cpu
else:
return devices.device
def send_model_to_device(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
- else:
+ lowvram.apply(m)
+
+ if not m.lowvram:
m.to(shared.device)
@@ -601,7 +602,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
'': torch.float16,
}
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
@@ -743,7 +744,7 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
new file mode 100644
index 00000000..5ffd2f4f
--- /dev/null
+++ b/modules/sd_models_types.py
@@ -0,0 +1,31 @@
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ from modules.sd_models import CheckpointInfo
+
+
+class WebuiSdModel(LatentDiffusion):
+ """This class is not actually instantinated, but its fields are created and fieeld by webui"""
+
+ lowvram: bool
+ """True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
+
+ sd_model_hash: str
+ """short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
+
+ sd_model_checkpoint: str
+ """path to the file on disk that model weights were obtained from"""
+
+ sd_checkpoint_info: 'CheckpointInfo'
+ """structure with additional information about the file with model's weights"""
+
+ is_sdxl: bool
+ """True if the model's architecture is SDXL"""
+
+ is_sd2: bool
+ """True if the model's architecture is SD 2.x"""
+
+ is_sd1: bool
+ """True if the model's architecture is SD 1.x"""
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 6d708ad2..5525cfbc 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -47,7 +47,7 @@ def apply_unet(option=None):
if current_unet_option is None:
current_unet = None
- if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ if not shared.sd_model.lowvram:
shared.sd_model.model.diffusion_model.to(devices.device)
return
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index ee118656..669097da 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -263,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
if loaded_vae_file == vae_file:
return
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if sd_model.lowvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
@@ -275,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
print("VAE weights loaded.")
diff --git a/modules/shared.py b/modules/shared.py
index 0c57b712..63661939 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -2,16 +2,15 @@ import sys
import gradio as gr
-from modules import shared_cmd_options, shared_gradio_themes, options, shared_items
+from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
-from ldm.models.diffusion.ddpm import LatentDiffusion
from modules import util
cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
-parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
+parallel_processing_allowed = True
styles_filename = cmd_opts.styles_file
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
@@ -40,7 +39,7 @@ options_templates = None
opts = None
restricted_opts = None
-sd_model: LatentDiffusion = None
+sd_model: sd_models_types.WebuiSdModel = None
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""