diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 14 | ||||
-rw-r--r-- | modules/initialize_util.py | 6 | ||||
-rw-r--r-- | modules/processing.py | 5 | ||||
-rw-r--r-- | modules/sd_hijack.py | 4 | ||||
-rw-r--r-- | modules/sd_models.py | 21 | ||||
-rw-r--r-- | modules/sd_models_config.py | 5 | ||||
-rw-r--r-- | modules/shared_options.py | 3 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 74 | ||||
-rw-r--r-- | modules/ui.py | 2 | ||||
-rw-r--r-- | modules/ui_settings.py | 24 | ||||
-rw-r--r-- | modules/xlmr_m18.py | 164 |
11 files changed, 252 insertions, 70 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 905ef9c9..09083874 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin,Image -from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases +from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -103,7 +102,8 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - + if isinstance(image, str): + return image if opts.samples_format.lower() == 'png': use_metadata = False metadata = PngImagePlugin.PngInfo() @@ -540,12 +540,12 @@ class Api: return {} def unloadapi(self): - unload_model_weights() + sd_models.unload_model_weights() return {} def reloadapi(self): - reload_model_weights() + sd_models.send_model_to_device(shared.sd_model) return {} @@ -565,7 +565,7 @@ class Api: def set_config(self, req: dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: + if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2894eee4..2e9b6d89 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -150,10 +150,14 @@ def dumpstacks(): def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
+
+ from modules import shared
+
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
- dumpstacks()
+ if shared.opts.dump_stacks_on_signal:
+ dumpstacks()
os._exit(0)
diff --git a/modules/processing.py b/modules/processing.py index 36bc94f7..40598f5c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -711,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None:
p.scripts.before_process(p)
- stored_opts = {k: opts.data[k] for k in p.override_settings.keys() if k in opts.data}
+ stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@@ -960,6 +960,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: state.nextjob()
+ if not infotexts:
+ infotexts.append(Processed(p, []).infotext(p, 0))
+
p.color_corrections = None
index_of_first_image = 0
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 22a1eb5c..bc5fbcd3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from types import MethodType from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -211,7 +211,7 @@ class StableDiffusionModelHijack: else:
m.cond_stage_model = conditioner
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
diff --git a/modules/sd_models.py b/modules/sd_models.py index 7f8502f5..3b6cdea1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,7 +1,6 @@ import collections
import os.path
import sys
-import gc
import threading
import torch
@@ -357,12 +356,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
-
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
@@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
- timer = Timer()
-
- if model_data.sd_model:
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- model_data.sd_model = None
- sd_model = None
- gc.collect()
- devices.torch_gc()
-
- print(f"Unloaded weights {timer.summary()}.")
+ send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 08dd03f1..deab2f6e 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
+config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
@@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
+
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+ if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
+ return config_alt_diffusion_m18
return config_alt_diffusion
return config_default
diff --git a/modules/shared_options.py b/modules/shared_options.py index a674d3da..32bf7353 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -62,6 +62,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
+
+ "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -110,6 +112,7 @@ options_templates.update(options_section(('system', "System"), { "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
+ "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
}))
options_templates.update(options_section(('API', "API"), {
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 401a0a2a..04dda585 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,40 +181,7 @@ class EmbeddingDatabase: else:
return
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- shape = vec.shape[-1]
- vectors = vec.shape[0]
- elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
- vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
- shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
- vectors = data['clip_g'].shape[0]
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- shape = vec.shape[-1]
- vectors = vec.shape[0]
- else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vectors
- embedding.shape = shape
- embedding.filename = path
- embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+ embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): return fn
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+ if 'string_to_param' in data: # textual inversion embeddings
+ param_dict = data['string_to_param']
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+ vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+ shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+ vectors = data['clip_g'].shape[0]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vectors
+ embedding.shape = shape
+
+ if filepath:
+ embedding.filename = filepath
+ embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+ return embedding
+
+
def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
diff --git a/modules/ui.py b/modules/ui.py index 3d1f5285..bcf39199 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1296,7 +1296,7 @@ def create_ui(): loadsave.setup_ui()
- if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 74a3aef3..e054d00a 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,6 +1,6 @@ import gradio as gr
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
@@ -177,8 +177,8 @@ class UiSettings: download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row():
- unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
- reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
with gr.Row():
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
@@ -194,16 +194,26 @@ class UiSettings: self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+ def call_func_and_return_text(func, text):
+ def handler():
+ t = timer.Timer()
+ func()
+ t.record(text)
+
+ return f'{text} in {t.total:.1f}s'
+
+ return handler
+
unload_sd_model.click(
- fn=sd_models.unload_model_weights,
+ fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
inputs=[],
- outputs=[]
+ outputs=[self.result]
)
reload_sd_model.click(
- fn=sd_models.reload_model_weights,
+ fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
inputs=[],
- outputs=[]
+ outputs=[self.result]
)
request_notifications.click(
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py new file mode 100644 index 00000000..a727e865 --- /dev/null +++ b/modules/xlmr_m18.py @@ -0,0 +1,164 @@ +from transformers import BertPreTrainedModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 1024 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + # self.pooler = lambda x: x[:,0] + # self.post_init() + + self.has_pre_transformation = True + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # # last module outputs + # sequence_output = outputs[0] + + + # # project every module + # sequence_output_ln = self.pre_LN(sequence_output) + + # # pooler + # pooler_output = self.pooler(sequence_output_ln) + # pooler_output = self.transformation(pooler_output) + # projection_state = self.transformation(outputs.last_hidden_state) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return { + "projection_state": projection_state2, + "last_hidden_state": outputs.last_hidden_state, + "hidden_states": outputs.hidden_states, + "attentions": outputs.attentions, + } + else: + projection_state = self.transformation(outputs.last_hidden_state) + return { + "projection_state": projection_state, + "last_hidden_state": outputs.last_hidden_state, + "hidden_states": outputs.hidden_states, + "attentions": outputs.attentions, + } + + + # return { + # 'pooler_output':pooler_output, + # 'last_hidden_state':outputs.last_hidden_state, + # 'hidden_states':outputs.hidden_states, + # 'attentions':outputs.attentions, + # 'projection_state':projection_state, + # 'sequence_out': sequence_output + # } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig |