Source code for finstmt.forecast.main

import dataclasses
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union

import matplotlib.pyplot as plt
import pandas as pd
from typing_extensions import Self

from finstmt.exc import ForecastNotFitException, ForecastNotPredictedException
from finstmt.forecast.config import ForecastConfig, ForecastItemConfig
from finstmt.forecast.models.chooser import get_model
from finstmt.forecast.models.manual import ManualForecastModel
from finstmt.items.config import ItemConfig

T = TypeVar("T")


[docs]@dataclass class Forecast: """ The main class to represent a forecast of an individual item. """ orig_series: pd.Series config: ForecastConfig item_config: ForecastItemConfig base_config: ItemConfig pct_of_series: Optional[pd.Series] = None pct_of_config: Optional[ItemConfig] = None def __post_init__(self): self.model = get_model(self.config, self.item_config, self.base_config)
[docs] def fit(self): self.model.fit(self.series)
[docs] def predict(self) -> pd.Series: if not self.model.has_been_fit: raise ForecastNotFitException("call .fit before .predict") return self.model.predict()
[docs] def plot( self, ax: Optional[plt.Axes] = None, figsize: Tuple[int, int] = (12, 5) ) -> plt.Figure: if not self.model.has_prediction: raise ForecastNotPredictedException("call .predict before .plot") return self.model.plot(ax=ax, figsize=figsize, title=self.name)
[docs] def to_manual( self, use_levels: bool = False, adjustments: Optional[Union[Sequence[float], Dict[int, float]]] = None, replacements: Optional[Union[Sequence[float], Dict[int, float]]] = None, ): if not self.model.has_prediction: raise ForecastNotPredictedException( "call .fit then .predict before .to_manual" ) if use_levels: values = self.result.values config_key = "levels" reset_key = "growth" else: # Growth values = self.result.pct_change().values # Fill in first growth values[0] = (self.result.iloc[0] - self.series.iloc[-1]) / self.series.iloc[ -1 ] config_key = "growth" reset_key = "levels" self.item_config.method = "manual" if adjustments is not None: if not isinstance(adjustments, dict) and len(adjustments) != len(values): raise ValueError( f"must pass equal length adjustments as number of periods. " f"Got {len(adjustments)} adjustments for {len(values)} periods" ) adjustments = _adjust_to_dict(adjustments) for i, adj in adjustments.items(): values[i] += adj if replacements is not None: if not isinstance(replacements, dict) and len(replacements) != len(values): raise ValueError( f"must pass equal length adjustments as number of periods. " f"Got {len(replacements)} adjustments for {len(values)} periods" ) replacements = _replace_to_dict(replacements) for i, replace in replacements.items(): values[i] = replace self.item_config.manual_forecasts[config_key] = list(values) self.item_config.manual_forecasts[reset_key] = [] self.model = ManualForecastModel( self.config, self.item_config, self.base_config ) self.model.fit(self.series) self.model.predict()
@property def series(self) -> pd.Series: if self.pct_of_series is None: return self.orig_series else: return self.orig_series / self.pct_of_series @property def result(self) -> pd.Series: if not self.model.has_prediction: raise ForecastNotPredictedException( "call .fit then .predict before .result" ) return self.model.result @property def name(self) -> str: if self.pct_of_config is None: return self.base_config.display_name # Percentage of series return f"{self.base_config.display_name} % {self.pct_of_config.display_name}"
[docs] def copy(self, **updates) -> Self: return dataclasses.replace(self, **updates)
def __round__(self, n: Optional[int] = None) -> "Forecast": return _apply_operation_to_forecast(self, n, round) # type: ignore[arg-type] def __add__(self, other: T) -> "Forecast": return _apply_operation_to_forecast(self, other, operator.add) def __sub__(self, other: T) -> "Forecast": return _apply_operation_to_forecast(self, other, operator.sub) def __mul__(self, other: T) -> "Forecast": return _apply_operation_to_forecast(self, other, operator.mul) def __truediv__(self, other: T) -> "Forecast": return _apply_operation_to_forecast(self, other, operator.truediv)
def _apply_operation_to_forecast( forecast: Forecast, other: T, func: Callable[[Any, T], Any], ) -> Forecast: updates: Dict[str, Any] = {} updates["orig_series"] = func( forecast.orig_series, _get_attr_if_needed(other, "orig_series") ) if forecast.pct_of_series is not None: updates["pct_of_series"] = func( forecast.pct_of_series, _get_attr_if_needed(other, "pct_of_series") ) if forecast.pct_of_config is not None: updates["pct_of_config"] = func( forecast.pct_of_config, _get_attr_if_needed(other, "pct_of_config") ) updates["item_config"] = func( forecast.item_config, _get_attr_if_needed(other, "item_config") ) updates["base_config"] = func( forecast.base_config, _get_attr_if_needed(other, "base_config") ) return forecast.copy(**updates) def _get_attr_if_needed(other: Any, attr: str) -> Any: if isinstance(other, Forecast): return getattr(other, attr) else: return other def _adjust_to_dict( seq_or_dict: Union[Sequence[float], Dict[int, float]] ) -> Dict[int, float]: if isinstance(seq_or_dict, dict): return seq_or_dict # If adjustment is 0 or None, skip the entry, otherwise add to the dict adjust_dict = {i: val for i, val in enumerate(seq_or_dict) if val} return adjust_dict def _replace_to_dict( seq_or_dict: Union[Sequence[float], Dict[int, float]] ) -> Dict[int, float]: if isinstance(seq_or_dict, dict): return seq_or_dict replace_dict = {i: val for i, val in enumerate(seq_or_dict)} return replace_dict