Source code for finstmt.combined.combinator

import operator
from typing import TYPE_CHECKING, Any, Callable, Dict, Protocol, Tuple, TypeVar

from typing_extensions import TypeGuard

if TYPE_CHECKING:
    from finstmt.bs.main import BalanceSheets
    from finstmt.combined.statements import FinancialStatements
    from finstmt.forecast.main import Forecast
    from finstmt.forecast.statements import ForecastedFinancialStatements
    from finstmt.inc.main import IncomeStatements

# General-purpose, unconstrained type variable.
T = TypeVar("T")
# Type variable for the statements type a combinator operates on; its bound is
# the forward reference "FinancialStatements" (imported only under TYPE_CHECKING).
StatementT = TypeVar("StatementT", bound="FinancialStatements")


class StatementsCombinator(Protocol[StatementT]):
    """Structural interface for objects that combine a statements object with another operand.

    Each method receives the statements object and an arbitrary ``other``
    operand and is annotated to return an object of the same statements type.
    The protocol itself carries no behaviour; the method bodies are empty.
    """

    def add(self, statement: StatementT, other: Any) -> StatementT:
        """Combine ``statement`` with ``other`` by addition."""

    def subtract(self, statement: StatementT, other: Any) -> StatementT:
        """Combine ``statement`` with ``other`` by subtraction."""

    def multiply(self, statement: StatementT, other: Any) -> StatementT:
        """Combine ``statement`` with ``other`` by multiplication."""

    def divide(self, statement: StatementT, other: Any) -> StatementT:
        """Combine ``statement`` with ``other`` by division."""
class FinancialStatementsCombinator(StatementsCombinator["FinancialStatements"]):
    """Arithmetic combinator for ``FinancialStatements``.

    Every public method applies one binary operator from the :mod:`operator`
    module to the income statements and balance sheets of ``statement`` (via
    ``_apply_to_child_statements``) and returns ``statement.copy(...)`` with
    the two results substituted in.

    ``_apply_to_child_statements`` raises ``NotImplementedError`` when
    ``other`` is neither a ``float``/``int`` nor a ``FinancialStatements``.
    """

    @staticmethod
    def _combine(
        statement: "FinancialStatements",
        other: Any,
        func: Callable[[Any, Any], Any],
    ) -> "FinancialStatements":
        # Shared implementation of all four operations: only ``func`` differs.
        new_income, new_balance = _apply_to_child_statements(statement, other, func)
        return statement.copy(
            income_statements=new_income,
            balance_sheets=new_balance,
        )

    def add(
        self, statement: "FinancialStatements", other: Any
    ) -> "FinancialStatements":
        """Return a copy of ``statement`` with ``other`` added (``operator.add``)."""
        return self._combine(statement, other, operator.add)

    def subtract(
        self, statement: "FinancialStatements", other: Any
    ) -> "FinancialStatements":
        """Return a copy of ``statement`` with ``other`` subtracted (``operator.sub``)."""
        return self._combine(statement, other, operator.sub)

    def multiply(
        self, statement: "FinancialStatements", other: Any
    ) -> "FinancialStatements":
        """Return a copy of ``statement`` multiplied by ``other`` (``operator.mul``)."""
        return self._combine(statement, other, operator.mul)

    def divide(
        self, statement: "FinancialStatements", other: Any
    ) -> "FinancialStatements":
        """Return a copy of ``statement`` divided by ``other`` (``operator.truediv``)."""
        return self._combine(statement, other, operator.truediv)
class ForecastedFinancialStatementsCombinator(
    StatementsCombinator["ForecastedFinancialStatements"]
):
    """Arithmetic combinator for ``ForecastedFinancialStatements``.

    Every public method applies one binary operator from the :mod:`operator`
    module to the income statements and balance sheets of ``statement`` (via
    ``_apply_to_child_statements``) and to each entry of
    ``statement.forecasts`` (via ``_apply_to_forecasts``), then returns
    ``statement.copy(...)`` with the three results substituted in.

    Both helpers raise ``NotImplementedError`` for an unsupported ``other``.
    """

    @staticmethod
    def _combine(
        statement: "ForecastedFinancialStatements",
        other: Any,
        func: Callable[[Any, Any], Any],
    ) -> "ForecastedFinancialStatements":
        # Shared implementation of all four operations: only ``func`` differs.
        # The child statements are combined first, then the forecasts.
        new_income, new_balance = _apply_to_child_statements(statement, other, func)
        new_forecasts = _apply_to_forecasts(statement.forecasts, other, func)
        return statement.copy(
            income_statements=new_income,
            balance_sheets=new_balance,
            forecasts=new_forecasts,
        )

    def add(
        self, statement: "ForecastedFinancialStatements", other: Any
    ) -> "ForecastedFinancialStatements":
        """Return a copy of ``statement`` with ``other`` added (``operator.add``)."""
        return self._combine(statement, other, operator.add)

    def subtract(
        self, statement: "ForecastedFinancialStatements", other: Any
    ) -> "ForecastedFinancialStatements":
        """Return a copy of ``statement`` with ``other`` subtracted (``operator.sub``)."""
        return self._combine(statement, other, operator.sub)

    def multiply(
        self, statement: "ForecastedFinancialStatements", other: Any
    ) -> "ForecastedFinancialStatements":
        """Return a copy of ``statement`` multiplied by ``other`` (``operator.mul``)."""
        return self._combine(statement, other, operator.mul)

    def divide(
        self, statement: "ForecastedFinancialStatements", other: Any
    ) -> "ForecastedFinancialStatements":
        """Return a copy of ``statement`` divided by ``other`` (``operator.truediv``)."""
        return self._combine(statement, other, operator.truediv)
def _apply_to_child_statements( statements: "FinancialStatements", other: Any, func: Callable[[Any, Any], Any], ) -> Tuple["IncomeStatements", "BalanceSheets"]: from finstmt import FinancialStatements if isinstance(other, (float, int)): new_inc = func(statements.income_statements, other) new_bs = func(statements.balance_sheets, other) elif isinstance(other, FinancialStatements): new_inc = func(statements.income_statements, other.income_statements) new_bs = func(statements.balance_sheets, other.balance_sheets) else: raise NotImplementedError( f"cannot {func.__name__} type {type(statements)} to type {type(other)}" ) return new_inc, new_bs def _apply_to_forecasts( forecasts: Dict[str, "Forecast"], other: Any, func: Callable[["Forecast", Any], "Forecast"], ) -> Dict[str, "Forecast"]: if isinstance(other, (float, int)): return {k: func(v, other) for k, v in forecasts.items()} elif _is_forecasted_financial_statements(other): return {k: func(v, other.forecasts[k]) for k, v in forecasts.items()} else: raise NotImplementedError( f"cannot {func.__name__} type {type(forecasts)} to type {type(other)}" ) def _is_forecasted_financial_statements( obj: Any, ) -> TypeGuard["ForecastedFinancialStatements"]: return hasattr(obj, "forecasts")