Source code for finstmt.forecast.models.manual

from typing import Optional

import pandas as pd

from finstmt.exc import ImproperManualForecastException
from finstmt.forecast.config import ForecastConfig, ForecastItemConfig
from finstmt.forecast.models.base import ForecastModel
from finstmt.items.config import ItemConfig


class ManualForecastModel(ForecastModel):
    """Forecast model driven by user-supplied growth rates or levels.

    The manual inputs are read from ``item_config.manual_forecasts`` under
    the keys ``"growth"`` and ``"levels"``. Exactly one of the two must be
    non-empty, and it must contain one entry per forecast period
    (``config.periods``).

    Attributes:
        recent: Last value (by position) of the series passed to ``fit``;
            ``None`` until ``fit`` has been called.
    """

    recent: Optional[float] = None

    def __init__(
        self,
        config: ForecastConfig,
        item_config: ForecastItemConfig,
        base_config: ItemConfig,
    ):
        """Initialise the base model, then load and validate manual inputs.

        Raises:
            ImproperManualForecastException: If the manual inputs fail the
                checks performed by ``_validate``.
        """
        super().__init__(config, item_config, base_config)
        self._set_growths_levels()
        self._validate()

    def _set_growths_levels(self):
        """Copy the manual growth rates and levels out of the item config."""
        manual = self.item_config.manual_forecasts
        self.growths = manual["growth"]
        self.levels = manual["levels"]

    def _validate(self):
        """Check that exactly one manual input of the right length was given.

        Raises:
            ImproperManualForecastException: If neither or both of growth
                rates and levels are provided, or if the provided input does
                not have ``config.periods`` entries.
        """
        has_growths = bool(self.growths)
        has_levels = bool(self.levels)

        if not (has_growths or has_levels):
            raise ImproperManualForecastException(
                "must provide either growth or levels for manual forecast"
            )
        if has_growths and has_levels:
            raise ImproperManualForecastException(
                "must only provide one of growth or levels for manual forecast"
            )

        # Exactly one input is populated from here on; check its length.
        expected = self.config.periods
        if has_growths:
            label, provided = "growth rates", self.growths
        else:
            label, provided = "levels", self.levels

        if len(provided) != expected:
            raise ImproperManualForecastException(
                f"{len(provided)} {label} were provided for {expected} forecast periods"
            )

    def fit(self, series: pd.Series):
        """Remember the last value of ``series`` and delegate to the base fit.

        Args:
            series: Historical values; its final element becomes ``recent``,
                the starting point for growth-based forecasts.
        """
        self.recent = series.iloc[-1]
        super().fit(series)

    def predict(self) -> pd.Series:
        """Build the forecast from the manual growth rates or levels.

        With growth rates, each forecast value is the previous value
        (starting from ``recent``) multiplied by ``1 + growth``. With levels,
        the levels are used as the forecast values directly.

        Returns:
            The forecast series, indexed by ``self._future_date_range``
            (defined outside this class — presumably by the base model).
        """
        if self.growths:
            # Compound the growth rates forward from the last fitted value.
            projected = []
            level = self.recent
            for rate in self.growths:
                level = level * (1 + rate)
                projected.append(level)
        else:
            projected = self.levels

        self.result = pd.Series(projected, index=self._future_date_range)

        # History followed by forecast, exposed as a single "mean" column.
        combined = pd.concat([self.orig_series, self.result])
        self.result_df = pd.DataFrame(combined, columns=["mean"])

        super().predict()
        return self.result