aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordiscus0434 <discus0434@gmail.com>2022-10-22 11:07:00 +0000
committerdiscus0434 <discus0434@gmail.com>2022-10-22 11:07:00 +0000
commit0e8ca8e7af05be22d7d2c07a47c3c7febe0f0ab6 (patch)
treeaf463262310e211004f523d906456e16f1c80e5e
parent6a02841fff12892eedc979a57999a2d4fc4a9ed4 (diff)
add dropout
-rw-r--r--modules/hypernetworks/hypernetwork.py68
-rw-r--r--modules/hypernetworks/ui.py10
-rw-r--r--modules/ui.py43
3 files changed, 70 insertions, 51 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 905cbeef..e493f366 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,47 +1,60 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
+import modules.textual_inversion.dataset
import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
- activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
- "swish": torch.nn.Hardswish}
-
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+ activation_dict = {
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ }
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
+ assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
+
linears = []
for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- # if skip_first_layer because first parameters potentially contain negative values
- # if i < 1: continue
- if activation_func in HypernetworkModule.activation_dict:
- linears.append(HypernetworkModule.activation_dict[activation_func]())
+
+ # Add an activation func
+ if activation_func == "linear":
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
else:
- print("Invalid key {} encountered as activation function!".format(activation_func))
- # if use_dropout:
- # linears.append(torch.nn.Dropout(p=0.3))
+ raise NotImplementedError(
+ "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
+ )
+
+ # Add dropout
+ if use_dropout:
+ linears.append(torch.nn.Dropout(p=0.3))
+
+ # Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -93,7 +106,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -101,13 +114,14 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
- self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -129,8 +143,9 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
- state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -144,8 +159,9 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 1a5a27d8..5f6f17b6 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,14 +3,13 @@ import os
import re
import gradio as gr
-
-import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
- add_layer_norm=add_layer_norm,
activation_func=activation_func,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
)
hypernet.save(fn)
diff --git a/modules/ui.py b/modules/ui.py
index 716f14b8..d4b32c05 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,43 +5,44 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
-import gradio as gr
-import gradio.utils
-import gradio.routes
-
-from modules import sd_hijack, sd_models, localization
+from modules import localization, sd_hijack, sd_models
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import cmd_opts, opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
-import modules.textual_inversion.ui
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
with gr.Row():
with gr.Column(scale=3):
@@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name,
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
- new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
],
outputs=[
train_hypernetwork_name,