From 0a89cd1a584b1584a0609c0ba27fb35c434b0b68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 22:08:08 +0300 Subject: Use less RAM when creating models --- modules/sd_disable_initialization.py | 106 +++++++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 5 deletions(-) (limited to 'modules/sd_disable_initialization.py') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9fc89dc6..695c5736 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -3,8 +3,31 @@ import open_clip import torch import transformers.utils.hub +from modules import shared -class DisableInitialization: + +class ReplaceHelper: + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + + def restore(self): + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() + + +class DisableInitialization(ReplaceHelper): """ When an object of this class enters a `with` block, it starts: - preventing torch's layer initialization functions from working @@ -21,7 +44,7 @@ class DisableInitialization: """ def __init__(self, disable_clip=True): - self.replaced = [] + super().__init__() self.disable_clip = disable_clip def replace(self, obj, field, func): @@ -86,8 +109,81 @@ class DisableInitialization: self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - for obj, field, original in self.replaced: - setattr(obj, field, original) + self.restore() - self.replaced.clear() +class InitializeOnMeta(ReplaceHelper): + """ + Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, + which results in those parameters having no values and taking no memory. model.to() will be broken and + will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. + + Usage: + ``` + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + ``` + """ + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + def set_device(x): + x["device"] = "meta" + return x + + linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) + conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) + mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) + self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() + + +class LoadStateDictOnMeta(ReplaceHelper): + """ + Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. + As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. + Meant to be used together with InitializeOnMeta above. 
+ + Usage: + ``` + with sd_disable_initialization.LoadStateDictOnMeta(state_dict): + model.load_state_dict(state_dict, strict=False) + ``` + """ + + def __init__(self, state_dict, device): + super().__init__() + self.state_dict = state_dict + self.device = device + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta] + + for name, param in params: + if param.is_meta: + self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad) + + original(self, state_dict, prefix, *args, **kwargs) + + for name, _ in params: + key = prefix + name + if key in sd: + del sd[key] + + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() -- cgit v1.2.1 From 86221269f98ef9b21a6e6c9d04b86e2fb5cb33d3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 09:55:35 +0300 Subject: RAM optimization round 2 --- modules/sd_disable_initialization.py | 51 +++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 7 deletions(-) (limited to 'modules/sd_disable_initialization.py') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 695c5736..719eeb93 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -168,22 +168,59 @@ class LoadStateDictOnMeta(ReplaceHelper): device = self.device def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): - params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta] + used_param_keys = [] + + for name, param in self._parameters.items(): + if param is None: + continue + + key = prefix + name + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param + used_param_keys.append(key) - for name, param in params: if param.is_meta: - self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad) + dtype = sd_param.dtype if sd_param is not None else param.dtype + self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + + for name in self._buffers: + key = prefix + name + + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param + used_param_keys.append(key) original(self, state_dict, prefix, *args, **kwargs) - for name, _ in params: - key = prefix + name - if key in sd: - del sd[key] + for key in used_param_keys: + state_dict.pop(key, None) + + def load_state_dict(original, self, state_dict, strict=True): + """torch makes a lot of copies of the dictionary with weights, so just deleting entries from 
state_dict does not help + because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with + all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. + + In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). + + The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads + the function and does not call the original) the state dict will just fail to load because weights + would be on the meta device. + """ + + if state_dict == sd: + state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} + + original(self, state_dict, strict=strict) + module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) + module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) + group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) def __exit__(self, exc_type, exc_val, exc_tb): self.restore() -- cgit v1.2.1 From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 12:11:01 +0300 Subject: send weights to target device instead of CPU memory --- modules/sd_disable_initialization.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'modules/sd_disable_initialization.py') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 719eeb93..8863107a 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper): ``` """ - def __init__(self, state_dict, device): + def __init__(self, state_dict, device, weight_dtype_conversion=None): super().__init__() self.state_dict = state_dict self.device = device + self.weight_dtype_conversion = weight_dtype_conversion or {} + self.default_dtype = self.weight_dtype_conversion.get('') + + def get_weight_dtype(self, key): + key_first_term, _ = key.split('.', 1) + return self.weight_dtype_conversion.get(key_first_term, self.default_dtype) def __enter__(self): if shared.cmd_opts.disable_model_loading_ram_optimization: @@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper): sd = self.state_dict device = self.device - def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + def load_from_state_dict(original, module, state_dict, prefix, 
*args, **kwargs): used_param_keys = [] - for name, param in self._parameters.items(): + for name, param in module._parameters.items(): if param is None: continue key = prefix + name sd_param = sd.pop(key, None) if sd_param is not None: - state_dict[key] = sd_param + state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) used_param_keys.append(key) if param.is_meta: dtype = sd_param.dtype if sd_param is not None else param.dtype - self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) - for name in self._buffers: + for name in module._buffers: key = prefix + name sd_param = sd.pop(key, None) @@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper): state_dict[key] = sd_param used_param_keys.append(key) - original(self, state_dict, prefix, *args, **kwargs) + original(module, state_dict, prefix, *args, **kwargs) for key in used_param_keys: state_dict.pop(key, None) - def load_state_dict(original, self, state_dict, strict=True): + def load_state_dict(original, module, state_dict, strict=True): """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. @@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper): if state_dict == sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} - original(self, state_dict, strict=strict) + original(module, state_dict, strict=strict) module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) -- cgit v1.2.1
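
Taken together, the three patches let a model be built with its large parameters on the meta device and then filled in directly from the checkpoint, so peak RAM stays close to a single copy of the weights. Below is a minimal usage sketch combining the two context managers, following the Usage notes in the docstrings above. Names such as `instantiate_from_config`, `sd_config`, `state_dict`, `devices.device`, and the specific `weight_dtype_conversion` mapping are assumptions drawn from the surrounding webui/ldm codebase for illustration, not part of these patches.

```python
# Sketch only: shows how InitializeOnMeta and LoadStateDictOnMeta are meant to
# be combined.  sd_config, state_dict, devices.device and
# instantiate_from_config are assumed to exist in the surrounding codebase.
import torch

from ldm.util import instantiate_from_config           # assumed import
from modules import devices, sd_disable_initialization

# 1) Build the model with Linear/Conv2d/MultiheadAttention parameters on the
#    meta device, so instantiation allocates essentially no RAM for weights.
with sd_disable_initialization.InitializeOnMeta():
    sd_model = instantiate_from_config(sd_config.model)

# 2) Load weights straight from the checkpoint's state_dict onto the target
#    device, optionally converting dtype per top-level key (the
#    weight_dtype_conversion argument added in the third patch).
weight_dtype_conversion = {
    'first_stage_model': None,    # keep these weights at their stored dtype
    '': torch.float16,            # default dtype for every other key
}

with sd_disable_initialization.LoadStateDictOnMeta(
    state_dict,
    devices.device,
    weight_dtype_conversion=weight_dtype_conversion,
):
    sd_model.load_state_dict(state_dict, strict=False)
```

The design relies on two tricks visible in the diffs: each module's `_load_from_state_dict` pops its keys out of the shared `sd` dict as soon as they are consumed, so the checkpoint shrinks while the model grows, and the wrapped `load_state_dict` hands torch a meta-device copy of the dict, so however many internal copies torch makes, none of them hold real weight data.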