# Public names exported by a star-import of this module.
__all__ = [
    'ETARegistry',
    'Datum',
    'BaseETA',
    'AverageETA',
    'EMA_ETA',
]
import datetime
from abc import ABC, abstractmethod
from typing import NamedTuple
from ...bases.configs import Config
from ...utils import EMA
from ..registries import RunnerRegistry
class ETARegistry(RunnerRegistry):
    """Registry under which the ETA estimator classes are registered.

    Adds nothing to ``RunnerRegistry``; it exists as a distinct registry
    that the estimators in this module register themselves with via
    ``ETARegistry.register_()``.
    """
class Datum(NamedTuple):
    """A single progress sample: a progress value and a timestamp."""

    # Progress value of the sample (e.g. the value passed to an estimator).
    x: int
    # Moment associated with the sample.
    t: datetime.datetime
class BaseETA(ABC):
    """Abstract base for estimators of the time left until ``end``.

    A reference sample is recorded at construction. Calling the instance
    with the current progress value builds a fresh sample, asks
    :meth:`pace` for the current pace and scales it by the remaining
    distance ``end - x``.
    """

    def __init__(self, start: int, end: int) -> None:
        """Record the reference sample and the target value.

        Args:
            start: Progress value at construction time; stored together
                with the current time as the reference sample.
            end: Target progress value; estimates are proportional to
                ``end - x``.
        """
        self._start = self._datum(start)
        self._end = end

    def _datum(self, x: int) -> Datum:
        """Pair ``x`` with the current time.

        The timestamp comes from ``datetime.datetime.now()`` without a
        timezone argument, i.e. it is a naive local wall-clock time.
        """
        t = datetime.datetime.now()
        return Datum(x, t)

    @abstractmethod
    def pace(self, datum: Datum) -> float:
        """Return the current pace given the latest sample.

        ``__call__`` multiplies the result by the remaining distance and
        divides by 1000, so a pace expressed in milliseconds per unit of
        progress yields an estimate in seconds.

        Args:
            datum: The sample taken for the current call.

        Returns:
            The pace used to extrapolate the remaining time.
        """

    def __call__(self, x: int) -> float:
        """Estimate the time remaining from progress value ``x``.

        Args:
            x: Current progress value.

        Returns:
            ``pace * (end - x) / 1000``, where ``pace`` is whatever
            :meth:`pace` returns for a sample taken now.
        """
        datum = self._datum(x)
        pace = self.pace(datum)
        return pace * (self._end - x) / 1000
@ETARegistry.register_()
class AverageETA(BaseETA):
    """Estimator whose pace is the average since the reference sample."""

    def pace(self, datum: Datum) -> float:
        """Return elapsed milliseconds per unit of progress since the start.

        Args:
            datum: The current sample.

        Returns:
            The time elapsed between the reference sample and ``datum``,
            in milliseconds, divided by the progress made in between.

        Raises:
            ZeroDivisionError: If ``datum.x`` equals the reference
                sample's ``x``, i.e. no progress has been made yet.
        """
        elapsed = datum.t - self._start.t
        progress = datum.x - self._start.x
        # total_seconds() is in seconds; the factor 1000 converts to ms.
        return elapsed.total_seconds() * 1000 / progress
@ETARegistry.register_()
class EMA_ETA(AverageETA):  # noqa: N801 pylint: disable=invalid-name
    """Average-pace estimator whose pace is passed through an ``EMA``.

    Each call to :meth:`pace` computes the parent's average pace, feeds it
    together with the previously stored pace into the ``EMA`` helper, and
    stores the result for the next call.
    """

    def __init__(self, *args, ema: Config, **kwargs) -> None:
        """Set up the parent estimator and the ``EMA`` helper.

        Args:
            *args: Positional arguments forwarded to the parent
                constructor.
            ema: Mapping unpacked as keyword arguments into ``EMA``.
            **kwargs: Keyword arguments forwarded to the parent
                constructor.
        """
        super().__init__(*args, **kwargs)
        self._ema = EMA(**ema)
        # No smoothed pace exists until ``pace`` has been called once.
        self._pace: float | None = None

    def pace(self, datum: Datum) -> float:
        """Return the smoothed pace and remember it for the next call.

        Args:
            datum: The current sample.

        Returns:
            The value produced by the ``EMA`` helper from the previously
            stored pace and the parent's average pace for ``datum``.
        """
        pace = super().pace(datum)
        # NOTE(review): on the first call the stored pace is None;
        # presumably the EMA helper then returns the new value as-is --
        # confirm against ``utils.EMA``.
        pace = self._ema(self._pace, pace)
        self._pace = pace
        return pace