aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/network.py2
-rw-r--r--extensions-builtin/Lora/network_full.py4
-rw-r--r--extensions-builtin/Lora/network_glora.py10
-rw-r--r--extensions-builtin/Lora/network_hada.py12
-rw-r--r--extensions-builtin/Lora/network_ia3.py2
-rw-r--r--extensions-builtin/Lora/network_lokr.py18
-rw-r--r--extensions-builtin/Lora/network_lora.py6
-rw-r--r--extensions-builtin/Lora/network_norm.py4
-rw-r--r--extensions-builtin/Lora/networks.py6
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/devices.py57
-rw-r--r--modules/launch_utils.py4
-rw-r--r--modules/sd_models.py32
-rw-r--r--modules/sd_models_xl.py2
14 files changed, 124 insertions, 37 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 6021fd8d..a62e5eff 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -137,7 +137,7 @@ class NetworkModule:
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index bf6930e9..f221c95f 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
- updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
index 492d4870..efe5c681 100644
--- a/extensions-builtin/Lora/network_glora.py
+++ b/extensions-builtin/Lora/network_glora.py
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
- updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+ updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 5fcb0695..d95a0fd1 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index 7edc4249..96faeaf3 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
- w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+ w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 340acdab..fcdaeafd 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
def calc_updown(self, orig_weight):
if self.w1 is not None:
- w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = self.w1.to(orig_weight.device)
else:
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
- w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 26c0a72c..4cc40295 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, orig_weight):
- up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
- down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ up = self.up_model.weight.to(orig_weight.device)
+ down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
- mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
index ce450158..d25afcbb 100644
--- a/extensions-builtin/Lora/network_norm.py
+++ b/extensions-builtin/Lora/network_norm.py
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
- updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.w_norm.to(orig_weight.device)
if self.b_norm is not None:
- ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.b_norm.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 7f814706..0170dbfb 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -394,12 +394,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index a9fb9bfa..088d5dea 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -118,3 +118,5 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
+parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
+parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index 1d4eb563..d7c905c2 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -16,6 +16,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
+def cuda_no_autocast(device_id=None) -> bool:
+ if device_id is None:
+ device_id = get_cuda_device_id()
+ return (
+ torch.cuda.get_device_capability(device_id) == (7, 5)
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+ )
+
+
+def get_cuda_device_id():
+ return (
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ else 0
+ ) or torch.cuda.current_device()
+
+
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -60,8 +77,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
- if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+ if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -71,6 +87,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
+fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -91,12 +108,48 @@ def cond_cast_float(input):
nv_rng = None
+patch_module_list = [
+ torch.nn.Linear,
+ torch.nn.Conv2d,
+ torch.nn.MultiheadAttention,
+ torch.nn.GroupNorm,
+ torch.nn.LayerNorm,
+]
+
+@contextlib.contextmanager
+def manual_autocast():
+ def manual_cast_forward(self, *args, **kwargs):
+ org_dtype = next(self.parameters()).dtype
+ self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+ result = self.org_forward(*args, **kwargs)
+ self.to(org_dtype)
+ return result
+ for module_type in patch_module_list:
+ org_forward = module_type.forward
+ module_type.forward = manual_cast_forward
+ module_type.org_forward = org_forward
+ try:
+ yield None
+ finally:
+ for module_type in patch_module_list:
+ module_type.forward = module_type.org_forward
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
+ if fp8 and device==cpu:
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+ if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+ return manual_autocast()
+
+ if has_mps() and shared.cmd_opts.precision != "full":
+ return manual_autocast()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 8cdbafa5..636da679 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -308,8 +308,8 @@ def requirements_met(requirements_file):
def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 841402e8..a6c8b2fa 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -395,6 +395,38 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.cmd_opts.opt_unet_fp8_storage:
+ enable_fp8 = True
+ elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+
+ if enable_fp8:
+ devices.fp8 = True
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+
+ if devices.device == devices.cpu:
+ for module in model.model.diffusion_model.modules():
+ if isinstance(module, torch.nn.Conv2d):
+ module.to(torch.float8_e4m3fn)
+ elif isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+ else:
+ model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 01123321..11259a36 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -93,7 +93,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()