aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py58
1 files changed, 42 insertions, 16 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 4f7613a1..36f643e1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,6 +2,8 @@ import collections
import os.path
import sys
import gc
+import threading
+
import torch
import re
import safetensors.torch
@@ -45,7 +47,7 @@ class CheckpointInfo:
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
+ self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
@@ -67,7 +69,7 @@ class CheckpointInfo:
checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
if self.sha256 is None:
return
@@ -404,13 +406,39 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+ def __init__(self):
+ self.sd_model = None
+ self.lock = threading.Lock()
+
+ def get_sd_model(self):
+ if self.sd_model is None:
+ with self.lock:
+ try:
+ load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load", file=sys.stderr)
+ self.sd_model = None
+
+ return self.sd_model
+
+ def set_sd_model(self, v):
+ self.sd_model = v
+
+
+model_data = SdModelData()
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
- if shared.sd_model:
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
gc.collect()
devices.torch_gc()
@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
timer.record("hijack")
sd_model.eval()
- shared.sd_model = sd_model
+ model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_info = info or select_checkpoint()
if not sd_model:
- sd_model = shared.sd_model
+ sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
current_checkpoint_info = None
@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
- return shared.sd_model
+ return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
return sd_model
+
def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
timer = Timer()
- if shared.sd_model:
-
- # shared.sd_model.cond_stage_model.to(devices.cpu)
- # shared.sd_model.first_stage_model.to(devices.cpu)
- shared.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ model_data.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()