# Source code for finstmt.forecast.models.prophet

import logging
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

from finstmt.forecast.config import ForecastConfig, ForecastItemConfig
from finstmt.forecast.dataframe import add_cap_and_floor_to_df
from finstmt.forecast.models.base import ForecastModel
from finstmt.items.config import ItemConfig


class FBProphetModel(ForecastModel):
    """Forecast model that delegates fitting and prediction to Facebook Prophet.

    A ``Prophet`` instance is created in ``__init__`` from the merged keyword
    arguments of the item-level and global forecast configs and stored on
    ``self.model``.
    """

    def __init__(
        self,
        config: ForecastConfig,
        item_config: ForecastItemConfig,
        base_config: ItemConfig,
    ):
        super().__init__(config, item_config, base_config)
        Prophet = _try_import_prophet()

        all_kwargs = {}
        if self._is_annual_freq(config.freq):
            # Annual data has no within-year pattern to learn, so switch off
            # Prophet's yearly seasonality by default.
            all_kwargs["yearly_seasonality"] = False
        # User-supplied kwargs are applied afterwards so they can override the
        # default above.
        # NOTE(review): the global ``config`` kwargs are applied last and thus
        # win over the item-level ones -- confirm this precedence is intended.
        all_kwargs.update(item_config.prophet_kwargs)
        all_kwargs.update(config.prophet_kwargs)
        self.model = Prophet(**all_kwargs)

    @staticmethod
    def _is_annual_freq(freq: str) -> bool:
        """Return True if ``freq`` is a pandas offset alias for a yearly frequency.

        Matching is case-insensitive and ignores an optional leading multiple
        (``"2Y"``) and an optional anchor suffix (``"A-DEC"``, ``"YE-JUN"``).
        Both legacy (``"A"``, ``"AS"``, ``"Y"``) and pandas >= 2.2 (``"YE"``,
        ``"BYE"``) spellings are recognised.
        """
        alias = freq.lower().split("-", 1)[0].lstrip("0123456789")
        return alias in {"y", "ye", "ys", "a", "as", "ba", "bas", "by", "bye", "bys"}

    def fit(self, series: pd.Series):
        """Fit the Prophet model to a historical series.

        :param series: Historical values; its index becomes Prophet's ``ds``
            column and its values the ``y`` column.
        """
        self.model.fit(self._df_for_fit(series))
        super().fit(series)

    def predict(self) -> pd.Series:
        """Forecast with Prophet and return only periods after the history.

        The full Prophet output frame is stored on ``self.result_df`` (used by
        :meth:`plot`) and the returned series on ``self.result``.

        :return: Point forecasts (``yhat``) indexed by ``ds``, restricted to
            dates later than ``self.last_historical_period``.
        """
        future = self.model.make_future_dataframe(
            **self.config.make_future_df_kwargs
        )
        add_cap_and_floor_to_df(future, self.item_config.cap, self.item_config.floor)
        forecast = self.model.predict(future)
        self.result_df = forecast

        result = forecast[["ds", "yhat"]].set_index("ds")["yhat"]
        # The future frame may also contain historical dates (e.g. Prophet's
        # ``include_history`` option); keep only true out-of-sample periods.
        result = result[result.index > self.last_historical_period]
        self.result = result
        super().predict()
        return result

    def plot(
        self,
        ax: Optional[plt.Axes] = None,
        figsize: Tuple[int, int] = (12, 5),
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        title: Optional[str] = None,
    ) -> plt.Figure:
        """Plot history and forecast using Prophet's built-in plotting.

        Requires ``self.result_df``, which :meth:`predict` sets.

        :param ax: Existing axes to draw on; passed through to Prophet's plot.
        :param figsize: Figure size passed through to Prophet's plot.
        :param xlabel: X-axis label; defaults to ``"Time"``.
        :param ylabel: Y-axis label passed through to Prophet's plot.
        :param title: Plot title; defaults to the item's display name. A falsy
            title leaves the axes untitled.
        :return: The matplotlib figure returned by Prophet's plot.
        """
        if xlabel is None:
            xlabel = "Time"
        if title is None:
            title = self.base_config.display_name

        fig = self.model.plot(
            self.result_df, ax=ax, figsize=figsize, xlabel=xlabel, ylabel=ylabel
        )
        if title:
            if ax is not None:
                ax.set_title(title)
            else:
                _set_title_on_axes(fig, title)

        # Presumably done so pyplot/notebook backends do not display the figure
        # a second time. Close exactly the figure being returned, not whatever
        # figure happens to be "current". When ``ax`` is given, this closes the
        # caller's figure (the one containing ``ax``).
        plt.close(fig)
        return fig

    def _df_for_fit(self, series: pd.Series) -> pd.DataFrame:
        """Convert ``series`` into the ``ds``/``y`` frame Prophet expects.

        Cap/floor columns from the item config are then added by
        ``add_cap_and_floor_to_df``. That helper is called for its in-place
        effect; its return value is not used.
        """
        df = pd.DataFrame(series).reset_index()
        df.columns = ["ds", "y"]
        add_cap_and_floor_to_df(df, self.item_config.cap, self.item_config.floor)
        return df
def _try_import_prophet():
    """Import and return the ``Prophet`` class, quieting prophet's logger.

    :return: The ``prophet.Prophet`` class.
    :raises ImportError: If the ``prophet`` package cannot be imported; the
        original import error is attached as ``__cause__``.
    """
    try:
        from prophet import Prophet
    except ImportError as err:
        raise ImportError(
            "need to install prophet to use forecasting functionality with method auto. "
            "see https://facebook.github.io/prophet/docs/installation.html"
        ) from err

    # Suppress excessive logging from prophet: only WARNING and above pass.
    logging.getLogger("prophet").setLevel(logging.WARNING)
    return Prophet


def _set_title_on_axes(fig: "plt.Figure", title: str):
    """Set ``title`` on every axes contained in ``fig``."""
    for ax in fig.axes:
        ax.set_title(title)