From 3d8b1af6beb9015f6b3573661d8ed00275f6129f Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 10 Oct 2023 12:09:33 +0800
Subject: Support string_to_param nested dict format:
 bundle_emb.EMBNAME.string_to_param.KEYNAME

---
 extensions-builtin/Lora/networks.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 652b8ebe..ab3517d8 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -157,7 +157,11 @@ def load_network(name, network_on_disk):
         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)
             emb_dict = bundle_embeddings.get(emb_name, {})
-            emb_dict[vec_name] = weight
+            if vec_name.split('.')[0] == 'string_to_param':
+                _, k2 = vec_name.split('.', 1)
+                emb_dict['string_to_param'] = {k2: weight}
+            else:
+                emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict

         key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
@@ -301,6 +305,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No

             if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                 emb_db.register_embedding(embedding, shared.sd_model)
+                print(f'registered bundle embedding: {embedding.name}')
             else:
                 emb_db.skipped_embeddings[name] = embedding

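
The sketch below is a minimal, standalone illustration (not part of the patch) of how keys of the
form bundle_emb.EMBNAME.string_to_param.KEYNAME are grouped into per-embedding dicts by the logic
added above. The sample state-dict keys, the zero tensors, and the loop scaffolding are hypothetical
stand-ins; only the vec_name splitting mirrors the first hunk of the diff.

import torch

# Hypothetical bundled-LoRA state dict: one embedding stored in the nested
# string_to_param format handled by this patch, one in the flat format.
state_dict = {
    "bundle_emb.my_emb.string_to_param.*": torch.zeros(2, 768),  # nested format
    "bundle_emb.other_emb.vec": torch.zeros(1, 768),             # flat format, unchanged behaviour
}

bundle_embeddings = {}
for key_network, weight in state_dict.items():
    # Split off the leading "bundle_emb" marker; the remainder is EMBNAME.<rest>.
    key_network_without_network_parts, _, network_part = key_network.partition(".")
    if key_network_without_network_parts != "bundle_emb":
        continue

    emb_name, vec_name = network_part.split(".", 1)
    emb_dict = bundle_embeddings.get(emb_name, {})
    if vec_name.split('.')[0] == 'string_to_param':
        # Nested format: rebuild the string_to_param sub-dict keyed by KEYNAME.
        _, k2 = vec_name.split('.', 1)
        emb_dict['string_to_param'] = {k2: weight}
    else:
        # Flat format: keep the tensor directly under its key.
        emb_dict[vec_name] = weight
    bundle_embeddings[emb_name] = emb_dict

print(bundle_embeddings["my_emb"].keys())     # dict_keys(['string_to_param'])
print(bundle_embeddings["other_emb"].keys())  # dict_keys(['vec'])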