aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
authorKeavon Chambers <keavon@keavon.com>2022-11-19 10:34:31 -0800
committerGitHub <noreply@github.com>2022-11-19 10:34:31 -0800
commit2f90496b19cd9c512633742db97b072a7075f017 (patch)
tree60d2dddd69172d9b5cf58c8da2bd64c61132f4fa /modules/textual_inversion/dataset.py
parenta258fd60dbe2d68325339405a2aa72816d06d2fd (diff)
parent47a44c7e421b98ca07e92dbf88769b04c9e28f86 (diff)
Merge branch 'master' into cors-regex
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index ad726577..eb75c376 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -98,7 +98,12 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", filename_text)
+ tags = filename_text.split(',')
+ if shared.opts.tag_drop_out != 0:
+ tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
+ if shared.opts.shuffle_tags:
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
return text
def __len__(self):