aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAngelBottomless <35677394+aria1th@users.noreply.github.com>2022-11-04 15:51:09 +0900
committerGitHub <noreply@github.com>2022-11-04 15:51:09 +0900
commit179702adc40cc8d9c97ae883ee9d0f7c79076047 (patch)
tree34b166971e0d8b5b2b8a7ec631d7395072f3f218 /modules
parent0d07cbfa15d34294a4fa22d74359cdd6fe2f799c (diff)
parentf2b69709eaff88fc3a2bd49585556ec0883bf5ea (diff)
Merge branch 'AUTOMATIC1111:master' into force-push-patch-13
Diffstat (limited to 'modules')
-rw-r--r--modules/api/models.py1
-rw-r--r--modules/hypernetworks/hypernetwork.py36
-rw-r--r--modules/masking.py2
-rw-r--r--modules/processing.py31
-rw-r--r--modules/scripts.py34
-rw-r--r--modules/sd_models.py5
-rw-r--r--modules/shared.py11
-rw-r--r--modules/ui.py28
8 files changed, 82 insertions, 66 deletions
diff --git a/modules/api/models.py b/modules/api/models.py
index 9ee42a17..68fb45c6 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -131,6 +131,7 @@ class ExtrasBaseRequest(BaseModel):
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+ upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index a11e01d6..6e1a10cf 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -35,7 +35,8 @@ class HypernetworkModule(torch.nn.Module):
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -48,8 +49,8 @@ class HypernetworkModule(torch.nn.Module):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- # Add an activation func
- if activation_func == "linear" or activation_func is None:
+ # Add an activation func except last layer
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
@@ -60,8 +61,8 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Add dropout expect last layer
- if use_dropout and i < len(layer_structure) - 3:
+ # Add dropout except last layer
+ if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
@@ -75,7 +76,7 @@ class HypernetworkModule(torch.nn.Module):
w, b = layer.weight.data, layer.bias.data
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
- normal_(b, mean=0.0, std=0.005)
+ normal_(b, mean=0.0, std=0)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
@@ -127,7 +128,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
@@ -139,11 +140,15 @@ class Hypernetwork:
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
+ self.activate_output = activate_output
+ self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
def weights(self):
@@ -171,7 +176,9 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-
+ state_dict['activate_output'] = self.activate_output
+ state_dict['last_layer_dropout'] = self.last_layer_dropout
+
torch.save(state_dict, filename)
def load(self, filename):
@@ -191,12 +198,17 @@ class Hypernetwork:
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
+ self.activate_output = state_dict.get('activate_output', True)
+ print(f"Activate last layer is set to {self.activate_output}")
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.name = state_dict.get('name', self.name)
diff --git a/modules/masking.py b/modules/masking.py
index fd8d9241..a5c4d2da 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
- desired_height = (x2 - x1) * ratio_processing
+ desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2-y1))
y1 -= desired_height_diff//2
y2 += desired_height_diff - desired_height_diff//2
diff --git a/modules/processing.py b/modules/processing.py
index 3a364b5f..a46e592d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return torch.zeros(
- x.shape[0], 5, 1, 1,
- dtype=x.dtype,
- device=x.device
- )
+ return x.new_zeros(x.shape[0], 5, 1, 1)
height = height or self.height
width = width or self.width
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
- return torch.zeros(
- latent_image.shape[0], 5, 1, 1,
- dtype=latent_image.dtype,
- device=latent_image.device
- )
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs
if image_mask is not None:
@@ -174,11 +166,11 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
- conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(source_image.device)
+ conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
@@ -426,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+ setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
- opts.data[k] = v
+ setattr(opts, k, v)
return res
@@ -674,6 +666,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+ # Avoid making the inpainting conditioning unless necessary as
+ # this does need some extra compute to decode / encode the image again.
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples)
for i in range(samples.shape[0]):
save_intermediate(samples, i)
@@ -700,14 +699,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- image_conditioning = self.txt2img_image_conditioning(x)
-
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
diff --git a/modules/scripts.py b/modules/scripts.py
index 533db45c..28ce07f4 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -18,6 +18,9 @@ class Script:
args_to = None
alwayson = False
+ """A gr.Group component that has all script's UI inside it"""
+ group = None
+
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
@@ -218,8 +221,6 @@ class ScriptRunner:
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- if not script.alwayson:
- control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
@@ -229,40 +230,41 @@ class ScriptRunner:
script.args_to = len(inputs)
for script in self.alwayson_scripts:
- with gr.Group():
+ with gr.Group() as group:
create_script_ui(script, inputs, inputs_alwayson)
+ script.group = group
+
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
- create_script_ui(script, inputs, inputs_alwayson)
+ with gr.Group(visible=False) as group:
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ script.group = group
def select_script(script_index):
- if 0 < script_index <= len(self.selectable_scripts):
- script = self.selectable_scripts[script_index-1]
- args_from = script.args_from
- args_to = script.args_to
- else:
- args_from = 0
- args_to = 0
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title):
+ """called when an initial value is set from ui-config.json to show script's UI components"""
+
if title == 'None':
return
+
script_index = self.titles.index(title)
- script = self.selectable_scripts[script_index]
- for i in range(script.args_from, script.args_to):
- inputs[i].visible = True
+ self.selectable_scripts[script_index].group.visible = True
dropdown.init_field = init_field
+
dropdown.change(
fn=select_script,
inputs=[dropdown],
- outputs=inputs
+ outputs=[script.group for script in self.selectable_scripts]
)
return inputs
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 5075fadb..ae427a5c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -204,8 +204,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
else:
- vae_name = sd_vae.get_filename(vae_file)
- print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+ vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+ vae_message = f" with {vae_name} VAE" if vae_name else ""
+ print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
checkpoints_loaded.move_to_end(checkpoint_key)
model.load_state_dict(checkpoints_loaded[checkpoint_key])
diff --git a/modules/shared.py b/modules/shared.py
index 7ecb40d8..4d6e1c8b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -397,6 +397,15 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data or key in self.data_labels:
+ assert not cmd_opts.freeze_settings, "changing settings is disabled"
+
+ comp_args = opts.data_labels[key].component_args
+ if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
self.data[key] = value
return
@@ -413,6 +422,8 @@ class Options:
return super(Options, self).__getattribute__(item)
def save(self, filename):
+ assert not cmd_opts.freeze_settings, "saving settings is disabled"
+
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
diff --git a/modules/ui.py b/modules/ui.py
index 2609857e..633b56ef 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1052,6 +1052,8 @@ def create_ui(wrap_gradio_gpu_call):
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
@@ -1079,8 +1081,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
@@ -1182,8 +1182,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
- new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1438,8 +1438,6 @@ def create_ui(wrap_gradio_gpu_call):
def run_settings(*args):
changed = 0
- assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
-
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
@@ -1448,15 +1446,9 @@ def create_ui(wrap_gradio_gpu_call):
if comp == dummy_component:
continue
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- continue
-
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- continue
-
oldval = opts.data.get(key, None)
- opts.data[key] = value
+
+ setattr(opts, key, value)
if oldval != value:
if opts.data_labels[key].onchange is not None:
@@ -1469,17 +1461,15 @@ def create_ui(wrap_gradio_gpu_call):
return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
- assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
-
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ try:
+ setattr(opts, key, value)
+ except Exception:
return gr.update(value=oldval), opts.dumpjson()
- opts.data[key] = value
-
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()