 modules/sd_models.py    |  6
 modules/sd_vae.py       |  3
 modules/xpu_specific.py | 22
 3 files changed, 23 insertions(+), 8 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 50bc209e..2c045771 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("script callbacks")
-
if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
+ script_callbacks.model_loaded_callback(sd_model)
+ timer.record("script callbacks")
+
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
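Note: the hunk above moves model_loaded_callback so that it fires only after the checkpoint has been moved back to its target device. That matters for extensions whose callback does device-sensitive setup. A minimal sketch of such a callback, assuming the webui's modules package is importable (the callback body is illustrative, not part of this change):

    from modules import script_callbacks

    def on_model_loaded(sd_model):
        # With this reordering, sd_model already lives on devices.device
        # (unless lowvram keeps it offloaded), so device-sensitive setup
        # is safe at this point.
        print(f"model loaded on {next(sd_model.parameters()).device}")

    script_callbacks.on_model_loaded(on_model_loaded)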
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 31306d8b..43687e48 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
if not sd_model.lowvram:
sd_model.to(devices.device)
+ script_callbacks.model_loaded_callback(sd_model)
+
print("VAE weights loaded.")
return sd_model
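The same reordering is applied to VAE reloads, so both code paths now present callbacks with a model that is already on its device. A quick sanity check one might run, assuming a checkpoint is already loaded (illustrative only):

    from modules import sd_vae, shared

    sd_model = sd_vae.reload_vae_weights(shared.sd_model)
    # first_stage_model is the VAE submodule; with this change it is back
    # on the target device before model_loaded_callback runs, matching
    # the reload_model_weights ordering above.
    print(next(sd_model.first_stage_model.parameters()).device)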
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index f7687a66..2971dbc3 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
# cast to same dtype first
key = key.to(query.dtype)
value = value.to(query.dtype)
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(query.dtype)
N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length
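The new attn_mask cast follows the usual scaled-dot-product-attention convention: boolean masks select positions and are used as-is, while additive float masks are added to the attention scores and therefore must match the query dtype. A standalone sketch of that rule in plain PyTorch, independent of the XPU path:

    import torch

    query = torch.randn(1, 8, 16, 64, dtype=torch.float16)
    # An additive mask built in float32 would mix dtypes with the fp16
    # attention scores; the patch casts it the same way before use.
    # Boolean masks are deliberately left alone.
    attn_mask = torch.zeros(1, 1, 16, 16)  # defaults to float32
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(query.dtype)
    assert attn_mask.dtype == torch.float16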
@@ -92,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
return torch.reshape(result, (*N, L, Ev))
+def is_xpu_device(device: str | torch.device = None):
+ if device is None:
+ return False
+ if isinstance(device, str):
+ return device.startswith("xpu")
+ return device.type == "xpu"
+
+
if has_xpu:
- # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
- CondFunc('torch.Generator',
- lambda orig_func, device=None: torch.xpu.Generator(device),
- lambda orig_func, device=None: device is not None and device.type == "xpu")
+ try:
+ # torch.Generator supports "xpu" device since 2.1
+ torch.Generator("xpu")
+ except RuntimeError:
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+ CondFunc('torch.Generator',
+ lambda orig_func, device=None: torch.xpu.Generator(device),
+ lambda orig_func, device=None: is_xpu_device(device))
# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',
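Two things change in the Generator workaround. First, the old predicate `device is not None and device.type == "xpu"` assumed a torch.device and would raise AttributeError when torch.Generator was called with a plain string such as "xpu"; the new is_xpu_device helper accepts strings, torch.device objects, and None. Second, the try/except probe turns an unconditional monkey-patch into feature detection, installing the CondFunc wrapper only on torch builds (< 2.1) where torch.Generator("xpu") actually raises. A quick illustration of the helper's behavior, assuming modules.xpu_specific imports cleanly in the webui environment:

    import torch
    from modules.xpu_specific import is_xpu_device

    print(is_xpu_device("xpu"))                # True
    print(is_xpu_device("xpu:0"))              # True, via startswith
    print(is_xpu_device(torch.device("cpu")))  # False
    print(is_xpu_device(None))                 # False; strings no longer
                                               # hit a .type attribute error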