WenjieDu / PyPOTS / build 10822158814

12 Sep 2024 12:59AM UTC coverage: 83.288% (+0.02%) from 83.266%

Pull Request #509: Apply line-length=120 to black format
Merge 4885f624e into fdd3d322a

437 of 618 new or added lines in 157 files covered. (70.71%)
85 existing lines in 39 files now uncovered.
11263 of 13523 relevant lines covered (83.29%)
4.99 hits per line

Source File: /pypots/optim/lr_scheduler/lambda_lrs.py (80.0% covered)
"""
Lambda learning rate scheduler.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Callable, Union

from .base import LRScheduler, logger


class LambdaLR(LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr times a given function.
    When last_epoch=-1, sets initial lr as lr.

    Parameters
    ----------
    lr_lambda: Callable or list,
        A function which computes a multiplicative factor given an integer parameter epoch, or a list of such
        functions, one for each group in optimizer.param_groups.

    last_epoch: int,
        The index of the last epoch. Default: -1.

    verbose: bool,
        If ``True``, prints a message to stdout for each update. Default: ``False``.

    Notes
    -----
    This class works the same as ``torch.optim.lr_scheduler.LambdaLR``.
    The only difference, which is also why we implement it, is that you don't have to pass the corresponding
    optimizer into it immediately while initializing it.

    Example
    -------
    >>> lambda1 = lambda epoch: epoch // 30
    >>> scheduler = LambdaLR(lr_lambda=lambda1)
    >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler)

    """

    def __init__(
        self,
        lr_lambda: Union[Callable, list],
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        super().__init__(last_epoch, verbose)
        self.lr_lambda = lr_lambda
        self.lr_lambdas = None

    def init_scheduler(self, optimizer):
        # A single callable is broadcast to every parameter group; a list/tuple must provide
        # exactly one lambda per group.
        if not isinstance(self.lr_lambda, list) and not isinstance(self.lr_lambda, tuple):
            self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups)
        else:
            if len(self.lr_lambda) != len(optimizer.param_groups):
                raise ValueError(
                    "Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(self.lr_lambda))
                )
            self.lr_lambdas = list(self.lr_lambda)

        super().init_scheduler(optimizer)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            logger.warning("⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.")

        return [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
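
The deferred binding described in the Notes is the point of this class: the scheduler is built before any optimizer exists, and PyPOTS attaches the optimizer later through ``init_scheduler``. A minimal usage sketch follows, assuming ``LambdaLR`` is re-exported from ``pypots.optim.lr_scheduler`` and that the PyPOTS ``Adam`` wrapper binds the scheduler internally; the decay lambda is illustrative.

    from pypots.optim import Adam
    from pypots.optim.lr_scheduler import LambdaLR  # assumed re-export of this module's class

    # Multiplicative decay: the effective learning rate at each epoch is lr * 0.95 ** epoch.
    scheduler = LambdaLR(lr_lambda=lambda epoch: 0.95 ** epoch)

    # No optimizer is required at construction time; the PyPOTS Adam wrapper receives the
    # scheduler and is expected to call its init_scheduler() once the model's parameters exist.
    adam = Adam(lr=1e-3, lr_scheduler=scheduler)

Passing a list of lambdas instead of a single callable is also supported, provided the list length matches the number of parameter groups in the optimizer; otherwise ``init_scheduler`` raises a ``ValueError``.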