aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorunknown <mcgpapu@gmail.com>2023-02-03 20:39:42 -0600
committerunknown <mcgpapu@gmail.com>2023-02-03 20:39:42 -0600
commit5e1f4f7464e560a2468812fc9d5cb38659cff86c (patch)
tree6b3e6676384fae53f3359aeea9ac51d32a5affd6 /modules
parentade40aa1a0605ba4aa3adc734ffb2b5121729d03 (diff)
parent226d840e84c5f306350b0681945989b86760e616 (diff)
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui into gamepad
Diffstat (limited to 'modules')
-rw-r--r--modules/devices.py49
-rw-r--r--modules/shared.py2
-rw-r--r--modules/ui.py4
3 files changed, 18 insertions, 37 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 655ca1d3..919048d0 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -2,6 +2,7 @@ import sys, os, shlex
import contextlib
import torch
from modules import errors
+from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -156,36 +157,7 @@ def test_for_nans(x, where):
raise NansException(message)
-# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-orig_tensor_to = torch.Tensor.to
-def tensor_to_fix(self, *args, **kwargs):
- if self.device.type != 'mps' and \
- ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
- (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
- self = self.contiguous()
- return orig_tensor_to(self, *args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-orig_layer_norm = torch.nn.functional.layer_norm
-def layer_norm_fix(*args, **kwargs):
- if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
- args = list(args)
- args[0] = args[0].contiguous()
- return orig_layer_norm(*args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-orig_tensor_numpy = torch.Tensor.numpy
-def numpy_fix(self, *args, **kwargs):
- if self.requires_grad:
- self = self.detach()
- return orig_tensor_numpy(self, *args, **kwargs)
-
-
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-orig_cumsum = torch.cumsum
-orig_Tensor_cumsum = torch.Tensor.cumsum
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
@@ -199,11 +171,20 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
if has_mps():
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
- torch.Tensor.numpy = numpy_fix
+
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
- torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
- torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+
diff --git a/modules/shared.py b/modules/shared.py
index 69634fd8..5600d480 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -327,7 +327,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
- "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
diff --git a/modules/ui.py b/modules/ui.py
index f910c582..5e34fb07 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -479,8 +479,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
if opts.dimensions_and_batch_together:
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
@@ -757,8 +757,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
if opts.dimensions_and_batch_together:
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")