aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/network.py2
-rw-r--r--extensions-builtin/Lora/network_full.py4
-rw-r--r--extensions-builtin/Lora/network_glora.py10
-rw-r--r--extensions-builtin/Lora/network_hada.py12
-rw-r--r--extensions-builtin/Lora/network_ia3.py2
-rw-r--r--extensions-builtin/Lora/network_lokr.py18
-rw-r--r--extensions-builtin/Lora/network_lora.py6
-rw-r--r--extensions-builtin/Lora/network_norm.py4
-rw-r--r--extensions-builtin/Lora/networks.py18
-rw-r--r--modules/devices.py60
-rw-r--r--modules/errors.py4
-rw-r--r--modules/initialize_util.py2
-rw-r--r--modules/launch_utils.py6
-rw-r--r--modules/sd_models.py49
-rw-r--r--modules/sd_models_xl.py2
-rw-r--r--modules/shared_options.py2
-rw-r--r--scripts/xyz_grid.py1
-rw-r--r--webui-macos-env.sh2
18 files changed, 158 insertions, 46 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 6021fd8d..a62e5eff 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -137,7 +137,7 @@ class NetworkModule:
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index bf6930e9..f221c95f 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
- updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
index 492d4870..efe5c681 100644
--- a/extensions-builtin/Lora/network_glora.py
+++ b/extensions-builtin/Lora/network_glora.py
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
- updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+ updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 5fcb0695..d95a0fd1 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index 7edc4249..96faeaf3 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
- w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+ w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 340acdab..fcdaeafd 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
def calc_updown(self, orig_weight):
if self.w1 is not None:
- w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = self.w1.to(orig_weight.device)
else:
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
- w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 26c0a72c..4cc40295 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, orig_weight):
- up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
- down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ up = self.up_model.weight.to(orig_weight.device)
+ down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
- mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
index ce450158..d25afcbb 100644
--- a/extensions-builtin/Lora/network_norm.py
+++ b/extensions-builtin/Lora/network_norm.py
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
- updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.w_norm.to(orig_weight.device)
if self.b_norm is not None:
- ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.b_norm.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 7f814706..d22ed843 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
- updown, ex_bias = module.calc_updown(self.weight)
+ if getattr(self, 'fp16_weight', None) is None:
+ weight = self.weight
+ bias = self.bias
+ else:
+ weight = self.fp16_weight.clone().to(self.weight.device)
+ bias = getattr(self, 'fp16_bias', None)
+ if bias is not None:
+ bias = bias.clone().to(self.bias.device)
+ updown, ex_bias = module.calc_updown(weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
diff --git a/modules/devices.py b/modules/devices.py
index ea1f712f..c956207f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -23,6 +23,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
+def cuda_no_autocast(device_id=None) -> bool:
+ if device_id is None:
+ device_id = get_cuda_device_id()
+ return (
+ torch.cuda.get_device_capability(device_id) == (7, 5)
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+ )
+
+
+def get_cuda_device_id():
+ return (
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ else 0
+ ) or torch.cuda.current_device()
+
+
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -73,8 +90,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
- if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+ if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -84,6 +100,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
+fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -104,12 +121,51 @@ def cond_cast_float(input):
nv_rng = None
+patch_module_list = [
+ torch.nn.Linear,
+ torch.nn.Conv2d,
+ torch.nn.MultiheadAttention,
+ torch.nn.GroupNorm,
+ torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(self, *args, **kwargs):
+ org_dtype = next(self.parameters()).dtype
+ self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+ result = self.org_forward(*args, **kwargs)
+ self.to(org_dtype)
+ return result
+
+
+@contextlib.contextmanager
+def manual_cast():
+ for module_type in patch_module_list:
+ org_forward = module_type.forward
+ module_type.forward = manual_cast_forward
+ module_type.org_forward = org_forward
+ try:
+ yield None
+ finally:
+ for module_type in patch_module_list:
+ module_type.forward = module_type.org_forward
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
+ if fp8 and device==cpu:
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+ if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+ return manual_cast()
+
+ if has_mps() and shared.cmd_opts.precision != "full":
+ return manual_cast()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/errors.py b/modules/errors.py
index eb234a83..c534a5d6 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -107,8 +107,8 @@ def check_versions():
import torch
import gradio
- expected_torch_version = "2.0.0"
- expected_xformers_version = "0.0.20"
+ expected_torch_version = "2.1.0"
+ expected_xformers_version = "0.0.22.post7"
expected_gradio_version = "3.41.2"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d89..b6767138 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,8 @@ def configure_opts_onchange():
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
startup_timer.record("opts onchange")
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 29506f24..2c54e2a0 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -314,8 +314,8 @@ def requirements_met(requirements_file):
def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
if args.use_ipex:
if platform.system() == "Windows":
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
@@ -338,7 +338,7 @@ def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7')
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9355f1e1..d0046f88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -348,10 +348,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
+def check_fp8(model):
+ if model is None:
+ return None
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.opts.fp8_storage == "Enable":
+ enable_fp8 = True
+ elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+ return enable_fp8
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
+ if devices.fp8:
+ # prevent model to load state dict in fp8
+ model.half()
+
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -404,6 +422,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
+ if check_fp8(model):
+ devices.fp8 = True
+ first_stage = model.first_stage_model
+ model.first_stage_model = None
+ for module in model.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.data.clone().cpu().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.data.clone().cpu().half()
+ module.to(torch.float8_e4m3fn)
+ model.first_stage_model = first_stage
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -746,7 +786,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -758,11 +798,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+ if check_fp8(sd_model) != devices.fp8:
+ # load from state dict again to prevent extra numerical errors
+ forced_reload = True
+ elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
- if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 01123321..11259a36 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -93,7 +93,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
diff --git a/modules/shared_options.py b/modules/shared_options.py
index e5de0d01..a860e355 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -206,6 +206,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+ "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+ "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 0dc255bc..b2250c04 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -270,6 +270,7 @@ axis_options = [
AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
+ AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
]
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 24bc5c42..db7e8b1a 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -11,7 +11,7 @@ fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
+export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0"
export PYTORCH_ENABLE_MPS_FALLBACK=1
####################################################################