aboutsummaryrefslogtreecommitdiff
path: root/modules/hashes.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-19 22:59:29 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-19 22:59:29 +0300
commit39ec4f06ffb2c26e1298b2c5d80874dc3fd693ac (patch)
tree7edd9b81f43b07826844e951241a4b4a15620257 /modules/hashes.py
parent87702febe08bddb9e8b952204c459927532ec6e6 (diff)
calculate hashes for Lora
add lora hashes to infotext when pasting infotext, use infotext's lora hashes to find local loras for <lora:xxx:1> entries whose hashes match loras the user has
Diffstat (limited to 'modules/hashes.py')
-rw-r--r--modules/hashes.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/modules/hashes.py b/modules/hashes.py
index 032120f4..8b7ea0ac 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -46,8 +46,8 @@ def calculate_sha256(filename):
return hash_sha256.hexdigest()
-def sha256_from_cache(filename, title):
- hashes = cache("hashes")
+def sha256_from_cache(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
ondisk_mtime = os.path.getmtime(filename)
if title not in hashes:
@@ -62,10 +62,10 @@ def sha256_from_cache(filename, title):
return cached_sha256
-def sha256(filename, title):
- hashes = cache("hashes")
+def sha256(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
- sha256_value = sha256_from_cache(filename, title)
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
if sha256_value is not None:
return sha256_value
@@ -73,7 +73,11 @@ def sha256(filename, title):
return None
print(f"Calculating sha256 for {filename}: ", end='')
- sha256_value = calculate_sha256(filename)
+ if use_addnet_hash:
+ with open(filename, "rb") as file:
+ sha256_value = addnet_hash_safetensors(file)
+ else:
+ sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {
@@ -86,6 +90,19 @@ def sha256(filename, title):
return sha256_value
+def addnet_hash_safetensors(b):
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()