# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from bisect import bisect_right
from typing import Iterable as Iter
from .lr_scheduler import LRScheduler
from .optimizer import Optimizer
class MultiStepLR(LRScheduler):
    r"""Decays the learning rate of each parameter group by gamma once the
        number of epoch reaches one of the milestones.

    Args:
        optimizer: wrapped optimizer.
        milestones: iterable of epoch indices in non-decreasing order. It is
            copied into a list, so one-shot iterables (e.g. generators) are
            accepted.
        gamma: multiplicative factor of learning rate decay. Default: 0.1
        current_epoch: the index of current epoch. Default: -1
    """

    # Attributes that make up the serializable state of this scheduler; shared
    # by ``state_dict`` and ``load_state_dict`` so the two cannot drift apart.
    _STATE_KEYS = ("milestones", "gamma", "current_epoch")

    def __init__(
        self,
        optimizer: Optimizer,
        milestones: Iter[int],
        gamma: float = 0.1,
        current_epoch: int = -1,
    ):
        # Materialize exactly once: a one-shot iterable would be exhausted by
        # the first pass of the ordering check, and ``bisect_right`` in
        # ``get_lr`` needs a real sequence anyway.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(
                    milestones
                )
            )
        self.milestones = milestones
        self.gamma = gamma
        # The scheduler's own attributes are set before delegating to the base
        # class initializer (kept in the original order).
        super().__init__(optimizer, current_epoch)

    def state_dict(self):
        r"""Returns the state of the scheduler as a :class:`dict`.

        Only the ``milestones``, ``gamma`` and ``current_epoch`` attributes
        are included; an attribute that is not set on the instance is omitted.
        """
        return {
            key: value
            for key, value in self.__dict__.items()
            if key in self._STATE_KEYS
        }

    def load_state_dict(self, state_dict):
        r"""Loads the schedulers state.

        Args:
            state_dict: scheduler state. Must contain the keys ``milestones``,
                ``gamma`` and ``current_epoch``; any other keys are ignored.

        Raises:
            KeyError: if one of the required keys is missing. The instance is
                left unmodified in that case.
        """
        # Collect into a temporary dict first so that a missing key aborts the
        # load before any attribute has been overwritten.
        tmp_dict = {}
        for key in self._STATE_KEYS:
            if key not in state_dict:
                raise KeyError(
                    "key '{}' is not specified in "
                    "state_dict when loading state dict".format(key)
                )
            tmp_dict[key] = state_dict[key]
        self.__dict__.update(tmp_dict)

    def get_lr(self):
        r"""Returns one learning rate per entry of ``self.base_lrs``.

        Each base learning rate is multiplied by ``gamma ** k``, where ``k`` is
        the number of milestones that are less than or equal to
        ``current_epoch``.
        """
        # bisect_right counts the milestones <= current_epoch; the resulting
        # factor is the same for every base learning rate, so compute it once.
        # NOTE(review): ``base_lrs`` is not assigned in this class; presumably
        # the LRScheduler base class provides it - confirm.
        decay = self.gamma ** bisect_right(self.milestones, self.current_epoch)
        return [base_lr * decay for base_lr in self.base_lrs]