aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/learn_schedule.py
blob: db7202712d01fa59a13912b4887cbd44970abb73 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

class LearnSchedule:
    """Iterator over a piecewise-constant learning-rate schedule.

    The schedule is a comma-separated string of ``rate:step`` entries, e.g.
    ``"0.005:100, 1e-3:1000, 1e-5"``:

    - rate 0.005 until step 100;
    - then 1e-3 until step 1000;
    - then 1e-5 until ``max_steps``.

    An entry with no ``:step`` part, or with a step of ``-1``, runs until
    ``max_steps`` and ends the schedule. Entries whose end step is not past
    ``cur_step`` are skipped, so training can resume part-way through.

    Iterating yields ``(rate, end_step)`` tuples, with ``end_step`` clamped to
    ``max_steps``. The object is its own iterator, so it can be consumed once.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """Parse ``learn_rate`` into ``self.rates``.

        Args:
            learn_rate: Schedule string as described in the class docstring.
            max_steps: Final training step; no segment ends after it.
            cur_step: Step training resumes from; segments ending at or
                before it are dropped.

        Raises:
            ValueError: If an entry is malformed (non-numeric rate or step,
                or more than one ``:``).
        """
        self.rates = self._parse(learn_rate, max_steps, cur_step)  # [(rate, end_step), ...]
        self.it = 0  # index of the next entry __next__ will return
        self.maxit = len(self.rates)  # number of entries available

    @staticmethod
    def _parse(learn_rate, max_steps, cur_step):
        """Return the list of ``(rate, end_step)`` segments for the schedule."""
        rates = []
        for pair in learn_rate.split(','):
            entry = pair.strip()
            if not entry:
                # Tolerate trailing / doubled commas such as "0.005:100,".
                continue

            fields = entry.split(':')
            try:
                if len(fields) > 2:
                    raise ValueError("too many ':' separators")
                rate = float(fields[0])
                # A bare rate behaves exactly like "rate:-1".
                step = int(fields[1]) if len(fields) == 2 else -1
            except ValueError as err:
                raise ValueError(
                    f"Invalid learning rate schedule entry {entry!r}; expected "
                    f"'rate' or 'rate:step', e.g. \"0.005:100, 1e-3:1000, 1e-5\""
                ) from err

            if step == -1:
                # -1 (or no step) means "until the end of training".
                rates.append((rate, max_steps))
                break
            if step <= cur_step:
                # This segment finished before the resume point.
                continue
            rates.append((rate, min(step, max_steps)))
            if step >= max_steps:
                # This segment already reaches the end of training; any later
                # entries could never take effect.
                break
        return rates

    def __iter__(self):
        return self

    def __next__(self):
        if self.it >= self.maxit:
            raise StopIteration
        self.it += 1
        return self.rates[self.it - 1]