__all__ = [
'BaseScheduler',
'WarmupScheduler',
'EarlyStopScheduler',
'DeferScheduler',
'DecayScheduler',
'StepScheduler',
'CosineAnnealingScheduler',
'ComposedScheduler',
'ChainedScheduler',
'SequentialScheduler',
]
import bisect
import math
from typing import Iterable, cast
import torch
from torch import nn
from ...bases.configs import Config
from ...bases.registries import BuildPreHookMixin, Item, RegistryMeta
from .registries import SchedulerRegistry
[docs]
@SchedulerRegistry.register_()
class BaseScheduler(nn.Module):
"""Base class for schedulers.
Under most cases, schedulers are used as a variable loss weight.
Schedulers are functions of `steps`, which could mean iterations or
epochs.
Users could increment `steps` by calling `step`, or directly set the
`steps` property.
Call the scheduler to get the value for the current step.
.. note:: `steps` starts from 1, so `step` should be called after the
first step.
The value of this scheduler is always the `gain`:
>>> base_scheduler = BaseScheduler(gain=5)
>>> base_scheduler()
5.0
>>> base_scheduler.step()
>>> base_scheduler()
5.0
"""
[docs]
def __init__(self, *args, gain: float = 1.0, **kwargs) -> None:
"""Initialize.
Args:
gain: multiplier to the scheduler value.
"""
super().__init__(*args, **kwargs)
self._gain = gain
self.register_forward_hook(lambda m, i, o: self._scale(o))
self.register_buffer('_steps', torch.tensor(1))
@property
def gain(self) -> float:
return self._gain
@property
def steps(self) -> int:
return cast(int, self._steps.item())
@steps.setter
def steps(self, value: int) -> None:
self._steps = torch.tensor(value)
[docs]
def step(self) -> None:
self._steps += 1
def _scale(self, output: float) -> float:
return output * self.gain
[docs]
def forward(self) -> float:
"""Compute the current schedule weight.
Returns:
The scheduler's value for the current step, before multiplying
`gain`.
Since `gain` is handled by this base class, it is usually adequate for
`forward` to return a percentage value in :math:`[0, 1]`.
"""
return 1.0
[docs]
@SchedulerRegistry.register_()
class WarmupScheduler(BaseScheduler):
"""Warmup scheduler.
The value will linearly increase from 0 to 1.
At step ``end``, the value is 1.
>>> warmup = WarmupScheduler(end=5)
>>> for _ in range(7):
... print(warmup())
... warmup.step()
0.2
0.4
0.6
0.8
1.0
1.0
1.0
"""
[docs]
def __init__(self, *args, start: int = 0, end: int, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._start = start
self._end = end
[docs]
def forward(self) -> float:
if self.steps <= self._start:
return 0.0
if self.steps >= self._end:
return 1.0
return (self.steps - self._start) / (self._end - self._start)
[docs]
@SchedulerRegistry.register_()
class DecayScheduler(WarmupScheduler):
"""Decay scheduler.
Before or at ``start``, the value is 1.
After or at ``end``, the value is 0.
In between, the value is interpolated.
>>> decay = DecayScheduler(start=2, end=7)
>>> for _ in range(8):
... print(round(decay(), 1))
... decay.step()
1.0
1.0
0.8
0.6
0.4
0.2
0.0
0.0
"""
[docs]
def forward(self) -> float:
return 1 - super().forward()
class JumpMixin(WarmupScheduler):
def __init__(self, *args, at: int, **kwargs) -> None:
super().__init__(*args, start=at - 1, end=at, **kwargs)
[docs]
@SchedulerRegistry.register_()
class DeferScheduler(JumpMixin, WarmupScheduler):
[docs]
def __init__(self, *args, to: int, **kwargs) -> None:
super().__init__(*args, at=to, **kwargs)
[docs]
@SchedulerRegistry.register_()
class EarlyStopScheduler(JumpMixin, DecayScheduler):
"""Early stop.
At some point, the value drops to 0 from 1.
>>> early_stop = EarlyStopScheduler(at=3)
>>> for _ in range(5):
... print(early_stop())
... early_stop.step()
1.0
1.0
0.0
0.0
0.0
"""
[docs]
@SchedulerRegistry.register_()
class StepScheduler(BaseScheduler):
"""Step scheduler.
The value is multiplied by :math:`gamma` at every milestone:
>>> step = StepScheduler(milestones=[3, 4], gamma=0.1)
>>> for _ in range(5):
... print(round(step(), 2))
... step.step()
1.0
1.0
0.1
0.01
0.01
"""
[docs]
def __init__(
self,
*args,
milestones: Iterable[int],
gamma: float,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._milestones = sorted(milestones)
self._gamma = gamma
[docs]
def forward(self) -> float:
return self._gamma**bisect.bisect(self._milestones, self.steps)
[docs]
@SchedulerRegistry.register_()
class CosineAnnealingScheduler(BaseScheduler):
"""Cosine annealing scheduler.
The value anneals as the cosine function is defined.
The first step starts with 1.
After ``duration`` steps, the value becomes 0.
The best practice is to set ``duration`` to the total number of steps.
>>> cosine = CosineAnnealingScheduler(duration=5)
>>> for _ in range(6):
... print(round(cosine(), 6))
... cosine.step()
1.0
0.904508
0.654508
0.345492
0.095492
0.0
"""
[docs]
def __init__(
self,
*args,
duration: int,
min_: float = 0.,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._duration = duration
self._min = min_
[docs]
def forward(self) -> float:
steps = self.steps - 1
if steps >= self._duration:
return 0
return (
self._min + (1 - self._min) *
(1 + math.cos(math.pi * steps / self._duration)) / 2
)
[docs]
class ComposedScheduler(BuildPreHookMixin, BaseScheduler):
[docs]
def __init__(
self,
*args,
schedulers: Iterable[BaseScheduler],
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._schedulers = tuple(schedulers)
[docs]
@classmethod
def build_pre_hook(
cls,
config: Config,
registry: RegistryMeta,
item: Item,
) -> Config:
config = super().build_pre_hook(config, registry, item)
config.schedulers = [
SchedulerRegistry.build_or_return(c) for c in config.schedulers
]
return config
[docs]
@SchedulerRegistry.register_()
class ChainedScheduler(ComposedScheduler):
"""Chained scheduler.
Schedulers are chained in an multiplicative manner:
>>> warmup = WarmupScheduler(end=5, gain=10)
>>> step = StepScheduler(milestones=[3, 4], gamma=0.1)
>>> chained = ChainedScheduler(schedulers=[warmup, step])
>>> for _ in range(5):
... print(round(chained(), 6))
... chained.step()
2.0
4.0
0.6
0.08
0.1
"""
[docs]
def forward(self) -> float:
return math.prod(scheduler() for scheduler in self._schedulers)
[docs]
def step(self) -> None:
super().step()
for scheduler in self._schedulers:
scheduler.step()
[docs]
@SchedulerRegistry.register_()
class SequentialScheduler(ComposedScheduler):
[docs]
def __init__(
self,
*args,
milestones: Iterable[int],
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._milestones = sorted(milestones)
@property
def scheduler(self) -> BaseScheduler:
i = bisect.bisect(self._milestones, self.steps)
return self._schedulers[i]
[docs]
def forward(self) -> float:
return self.scheduler()
[docs]
def step(self) -> None:
self.scheduler.step()
super().step()