import itertools
import timeit
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple
import numpy as np
import pandas as pd
from scipy.optimize import OptimizeResult, minimize
from sympy import Eq, Expr, IndexedBase, solve, sympify
from sympy.core.numbers import NaN
from finstmt.combined.statements import FinancialStatements
from finstmt.config_manage.data import _key_pct_of_key
from finstmt.config_manage.statements import StatementsConfigManager
from finstmt.exc import (
BalanceSheetNotBalancedException,
InvalidBalancePlugsException,
InvalidForecastEquationException,
MissingDataException,
)
from finstmt.forecast.main import Forecast
from finstmt.forecast.statements import ForecastedFinancialStatements
from finstmt.items.config import ItemConfig
from finstmt.logger import logger
from finstmt.resolver.base import ResolverBase
# TODO [#46]: clean up ForecastResolver
#
# `ForecastResolver` and associated logic is messy after reworking it multiple times.
# Need to remove unneeded code and restructure more logic into classes. `PlugResult`
# could handle more operations with the plugs, and the math could be more separated
# from the finance logic.
from finstmt.resolver.solve import (
PLUG_SCALE,
_get_indexed_symbols,
_solve_eqs_with_plug_solutions,
_symbolic_to_matrix,
_x_arr_to_plug_solutions,
solve_equations,
sympy_dict_to_results_dict,
)
class ForecastResolver(ResolverBase):
    """
    Resolve forecasted financial statements, optionally balancing the balance sheet.

    Forecasted item values from ``results`` are combined with the statement equations
    (sympy expressions indexed by period) and solved. When ``balance`` is True, the plug
    items are searched over by an optimizer until every balance group agrees within the
    tolerance derived from ``bs_diff_max``.

    :param stmts: The financial statements being forecasted
    :param forecast_dict: Forecasts by item key
    :param results: Forecasted values by item key
    :param bs_diff_max: Maximum difference used to derive the balancing tolerance
    :param timeout: Seconds to allow the balancing optimizer to run
    :param balance: Whether to balance the balance sheet using the plug items
    """

    def __init__(
        self,
        stmts: "FinancialStatements",
        forecast_dict: Dict[str, Forecast],
        results: Dict[str, pd.Series],
        bs_diff_max: float,
        timeout: float,
        balance: bool = True,
    ):
        self.forecast_dict = forecast_dict
        self.results = results
        self.bs_diff_max = bs_diff_max
        self.timeout = timeout
        self.balance = balance
        # When balancing, plug values are solved for by the optimizer (starting from their
        # forecasted values), so they must not be substituted in as fixed values
        self.exclude_plugs = bool(balance)
        super().__init__(stmts)

    def resolve_balance_sheet(self):
        """
        Solve the forecast equations while balancing the balance sheet via the plugs.

        :return: A copy of ``subs_dict`` updated with the solved values
        """
        logger.info("Balancing balance sheet")
        # Read once, in case subs_dict is a computed property on ResolverBase
        subs_dict = self.subs_dict
        solutions_dict = subs_dict.copy()
        stmts_config = self.stmts.config
        new_solutions = resolve_balance_sheet(
            self.plug_x0,
            self.solve_eqs,
            self.plug_keys,
            subs_dict,
            self.forecast_dates,
            stmts_config,
            stmts_config.sympy_namespace,
            self.bs_diff_max,
            stmts_config.balance_groups,
            self.timeout,
        )
        solutions_dict.update(new_solutions)
        return solutions_dict

    def to_statements(self) -> ForecastedFinancialStatements:
        """
        Resolve the forecast and package it as forecasted financial statements.

        When balancing, the plug forecasts in ``forecast_dict`` are converted to manual
        forecasts using the solved plug values.
        """
        if self.balance:
            solutions_dict = self.resolve_balance_sheet()
        else:
            solve_eqs = self.solve_eqs
            if solve_eqs:
                solutions_dict = solve_equations(solve_eqs, self.subs_dict)
            else:
                solutions_dict = self.subs_dict
        new_results = sympy_dict_to_results_dict(
            solutions_dict, self.forecast_dates, self.stmts.all_config_items, t_offset=1
        )
        if self.balance:
            # Update forecast dict for plug values
            for config in self.plug_configs:
                self.forecast_dict[config.key].to_manual(
                    use_levels=True, replacements=new_results[config.key].values
                )
        all_results = pd.concat(list(new_results.values()), axis=1).T
        inc_df = self.stmts.income_statements.__class__.from_df(
            all_results,
            self.stmts.income_statements.config.items,
            disp_unextracted=False,
        )
        bs_df = self.stmts.balance_sheets.__class__.from_df(
            all_results, self.stmts.balance_sheets.config.items, disp_unextracted=False
        )
        # type ignore added because for some reason mypy is not picking up structure
        # correctly since it is a dataclass
        obj = ForecastedFinancialStatements(inc_df, bs_df, forecasts=self.forecast_dict, calculate=False)  # type: ignore
        return obj

    @property
    def t_indexed_eqs(self) -> List[Eq]:
        """
        Equations for income statement and balance sheet items in terms of period ``t``.

        Items with an expression get ``key[t] = <expression>``. Forecasted items with a
        ``pct_of`` get ``key[t] = pct_of[t] * <percentage item>[t]``. Other items get no
        equation.
        """
        stmts_config = self.stmts.config
        item_configs = itertools.chain(
            self.stmts.income_statements.config.items,
            self.stmts.balance_sheets.config.items,
        )
        all_eqs = []
        for config in item_configs:
            lhs = sympify(config.key + "[t]", locals=stmts_config.sympy_namespace)
            forecast_config = config.forecast_config
            if config.expr_str is not None:
                rhs = stmts_config.expr_for(config.key)
            elif forecast_config.pct_of is not None and forecast_config.make_forecast:
                key_pct_of_key = _key_pct_of_key(config.key, forecast_config.pct_of)
                rhs = sympify(
                    f"{forecast_config.pct_of}[t] * {key_pct_of_key}[t]",
                    locals=stmts_config.sympy_namespace,
                )
            else:
                # Neither calculated nor a percentage forecast: nothing to solve for
                continue
            if rhs != lhs:
                all_eqs.append(Eq(lhs, rhs))
        return all_eqs

    @property
    def all_eqs(self) -> List[Eq]:
        """
        Period-indexed equations for every forecast period.

        The equations are passed through ``_get_equations_reformed_for_needed_solutions``,
        with the plug values and ``sympy_subs_dict`` as the hard-coded values.
        """
        t_eqs = self.t_indexed_eqs
        out_eqs = []
        # Starting from 1 as 0 is last historical period, no need to calculate
        for period in range(1, self.num_periods):
            out_eqs.extend(eq.subs({self.t: period}) for eq in t_eqs)
        all_hardcoded = _x_arr_to_plug_solutions(
            self.plug_x0, self.plug_keys, self.stmts.config.sympy_namespace
        )
        all_hardcoded.update(self.sympy_subs_dict)
        return _get_equations_reformed_for_needed_solutions(
            out_eqs, all_hardcoded, self.stmts.config
        )

    @property
    def num_periods(self) -> int:
        """Number of forecast periods plus one for the last historical period (period 0)."""
        # adding 1 because final existing period will be included as period 0
        return list(self.forecast_dict.values())[0].config.periods + 1

    @property
    def forecast_dates(self) -> pd.DatetimeIndex:
        """Index of the first forecast results series."""
        return list(self.results.values())[0].index

    @property
    def sympy_subs_dict(self) -> Dict[IndexedBase, float]:
        """
        Known values keyed by period-indexed sympy symbols.

        Period 0 takes the last historical value of each item. Later periods take the
        forecast results. Percent-of items use the key from ``_key_pct_of_key``. Plug
        items are skipped for forecast periods when ``exclude_plugs`` is set, and items
        without forecast results are skipped.
        """
        namespace = self.stmts.config.sympy_namespace
        nper = self.num_periods
        subs_dict = {}
        for config in self.stmts.all_config_items:
            if config.forecast_config.pct_of:
                key = _key_pct_of_key(config.key, config.forecast_config.pct_of)
            else:
                key = config.key
            for period in range(nper):
                lhs = sympify(f"{key}[{period}]", locals=namespace)
                if period == 0:
                    # period 0 is last historical period, not forecasted period
                    try:
                        value = getattr(self.stmts, key).iloc[-1]
                    except AttributeError as e:
                        if "_pct_" in str(e):
                            # Got a percentage of item, only in forecasted results, skip
                            continue
                        raise
                else:
                    # period 1 or later, forecasted period, get from forecast results
                    # If it is a plug item, don't get forecasted values
                    if self.exclude_plugs and config.forecast_config.plug:
                        continue
                    try:
                        series = self.results[key]
                    except KeyError:
                        # Must not be a forecasted item, probably calculated item
                        continue
                    value = series.iloc[period - 1]
                subs_dict[lhs] = value
        return subs_dict

    @property
    def bs_balance_eqs(self) -> List[Eq]:
        """Equations setting every pair of items in each balance group equal, per forecast period."""
        namespace = self.stmts.config.sympy_namespace
        eqs = []
        for balance_set in self.stmts.config.balance_groups:
            for period in range(1, self.num_periods):
                for lhs_key, rhs_key in itertools.combinations(balance_set, 2):
                    lhs = sympify(f"{lhs_key}[{period}]", locals=namespace)
                    rhs = sympify(f"{rhs_key}[{period}]", locals=namespace)
                    eqs.append(Eq(lhs, rhs))
        return eqs

    @property
    def plug_configs(self) -> List[ItemConfig]:
        """Configs of all items with ``forecast_config.plug`` set."""
        return [
            conf for conf in self.stmts.all_config_items if conf.forecast_config.plug
        ]

    @property
    def plug_keys(self) -> List[str]:
        """Keys of the plug items, in ``plug_configs`` order."""
        return [config.key for config in self.plug_configs]

    @property
    def plug_x0(self) -> np.ndarray:
        """
        Initial guess for the optimizer: the forecasted plug values divided by ``PLUG_SCALE``.

        Values are concatenated in ``plug_configs`` order, with all periods of one plug
        before the next. Returns an empty float array when there are no plug items.
        """
        x_arrs = [self.results[config.key].values for config in self.plug_configs]
        if not x_arrs:
            # np.concatenate raises on an empty list
            return np.array([], dtype=float)
        return np.concatenate(x_arrs) / PLUG_SCALE
@dataclass
class PlugResult:
    """
    Mutable record of the balancing optimizer's progress.

    Holds the most recently evaluated plug array and objective value, whether the
    balancing goal was met, and a timer. Timing starts at construction unless
    ``start_time`` is supplied.
    """

    # most recently evaluated plug array
    res: Optional[np.ndarray] = None
    # seconds allowed before is_timed_out becomes True
    timeout: float = 180
    # timeit.default_timer() reading when timing started
    start_time: Optional[float] = None
    # objective value for ``res``
    fun: Optional[float] = None
    # set once the balancing goal has been met
    met_goal: bool = False

    def __post_init__(self):
        # Begin the clock now unless the caller supplied a start time
        if self.start_time is None:
            self.start_time = timeit.default_timer()

    @property
    def time_elapsed(self) -> float:
        """Seconds elapsed since ``start_time``."""
        started = self.start_time
        if started is not None:
            return timeit.default_timer() - started
        raise ValueError("Must instantiate PlugResult to get time_elapsed")

    @property
    def is_timed_out(self) -> bool:
        """Whether more than ``timeout`` seconds have elapsed."""
        return self.timeout < self.time_elapsed
def resolve_balance_sheet(
    x0: np.ndarray,
    eqs: List[Eq],
    plug_keys: Sequence[str],
    subs_dict: Dict[IndexedBase, float],
    forecast_dates: pd.DatetimeIndex,
    config: StatementsConfigManager,
    sympy_namespace: Dict[str, IndexedBase],
    bs_diff_max: float,
    balance_groups: List[Set[str]],
    timeout: float,
) -> Dict[IndexedBase, float]:
    """
    Find plug values which balance the balance sheet, then solve all the equations.

    The equations, forecasted values and plug values are turned into a linear system.
    ``scipy.optimize.minimize`` (TNC, non-negative bounds) then searches over the plug
    values. The objective function stops the search early by raising once the balance
    groups agree or the timeout is hit.

    :param x0: Initial plug values divided by ``PLUG_SCALE``, in ``plug_keys`` order.
        Adjusted in place by the initial balance guess.
    :param eqs: Equations to solve alongside the forecasted and plug values
    :param plug_keys: Keys of the plug items
    :param subs_dict: Forecasted/known values by indexed symbol
    :param forecast_dates: Dates of the forecast periods
    :param config: Statements config
    :param sympy_namespace: Sympy symbols by item key
    :param bs_diff_max: Maximum difference used to derive the balancing tolerance
    :param balance_groups: Sets of item keys which must balance against each other
    :param timeout: Seconds before balancing gives up
    :return: Solved values
    :raises InvalidForecastEquationException: if an equation is NaN or an item is defined
        more than once
    :raises MissingDataException: if a value needed for the forecast is NaN
    :raises BalanceSheetNotBalancedException: if the balance sheet could not be balanced
    """
    plug_solutions = _x_arr_to_plug_solutions(x0, plug_keys, sympy_namespace)
    all_to_solve = _collect_expressions_to_solve(eqs, subs_dict, plug_solutions)
    to_solve_for = list(all_to_solve.keys())
    solve_exprs = list(all_to_solve.values())
    _check_for_invalid_system_of_equations(
        eqs, subs_dict, plug_solutions, to_solve_for, solve_exprs
    )
    # TODO: Is Symbol or IndexedBase the correct type here?
    eq_arrs = _symbolic_to_matrix(solve_exprs, to_solve_for)  # type: ignore[arg-type]
    # Get better initial x0 by adding to appropriate plug
    _adjust_x0_to_initial_balance_guess(
        x0, plug_keys, eq_arrs, forecast_dates, to_solve_for, config, balance_groups
    )
    result = PlugResult(timeout=timeout)
    try:
        minimize(
            _resolve_balance_sheet_check_diff,
            x0,
            args=(
                eq_arrs,
                forecast_dates,
                to_solve_for,
                bs_diff_max,
                balance_groups,
                result,
            ),
            bounds=[(0, None) for _ in range(len(x0))],  # plug values are non-negative
            method="TNC",
            options=dict(
                maxCGit=0,
                maxfun=1000000000,
            ),
        )
    except (BalanceSheetBalancedException, BalanceSheetNotBalancedException):
        # The objective raises these to stop the optimizer (goal met or timed out);
        # the latest evaluation is recorded on ``result``
        pass

    if not result.met_goal:
        if result.fun is None or result.res is None:
            # Mainly for mypy purposes
            raise BalanceSheetNotBalancedException(
                "Unexpected balancing error. Did not evaluate the balancing function even once"
            )
        raise BalanceSheetNotBalancedException(
            _not_balanced_message(
                result.res,
                result.fun,
                result.timeout,
                plug_keys,
                sympy_namespace,
                bs_diff_max,
            )
        )

    logger.info(f"Balanced in {result.time_elapsed:.1f}s")
    if result.res is None:
        raise BalanceSheetNotBalancedException(
            "Unexpected balancing error. No result found even though met_goal was True"
        )
    plug_solutions = _x_arr_to_plug_solutions(result.res, plug_keys, sympy_namespace)
    solutions_dict = _solve_eqs_with_plug_solutions(
        eqs, plug_solutions, subs_dict, forecast_dates, config.items
    )
    return solutions_dict


def _collect_expressions_to_solve(
    eqs: List[Eq],
    subs_dict: Dict[IndexedBase, float],
    plug_solutions: Dict[IndexedBase, float],
) -> Dict[IndexedBase, Expr]:
    """
    Map each unknown to an expression (``rhs - lhs``) which is zero when it is solved.

    Equations are added first, then forecasted values, then plug values, so the plug
    entries come last in the resulting system.

    :raises InvalidForecastEquationException: for a NaN equation or an item defined twice
    :raises MissingDataException: for a NaN forecasted/plug value
    """
    all_to_solve: Dict[IndexedBase, Expr] = {}
    # Raw forecasted/plug values, so that error messages can report the value itself
    hardcoded_values: Dict[IndexedBase, float] = {}
    for eq in eqs:
        expr = eq.rhs - eq.lhs
        if expr == NaN():
            raise InvalidForecastEquationException(
                f"got NaN forecast equation. LHS: {eq.lhs}, RHS: {eq.rhs}"
            )
        if eq.lhs in all_to_solve:
            raise InvalidForecastEquationException(
                f"got multiple equations to solve for {eq.lhs}. Already had {all_to_solve[eq.lhs]}, now got {expr}"
            )
        all_to_solve[eq.lhs] = expr
    for sol_dict in (subs_dict, plug_solutions):
        # Plug solutions second here so that they are at end of array
        for lhs, rhs in sol_dict.items():
            expr = rhs - lhs
            if expr == NaN():
                raise MissingDataException(
                    f"got NaN for {lhs} but that is needed for resolving the forecast"
                )
            if lhs in all_to_solve:
                if lhs in hardcoded_values:
                    had_message = f"forecast/plug value of {hardcoded_values[lhs]}"
                else:
                    had_message = f"equation of {all_to_solve[lhs]}"
                raise InvalidForecastEquationException(
                    f"got forecast/plug value for {lhs} but already had an existing {had_message}, now got {expr}"
                )
            all_to_solve[lhs] = expr
            hardcoded_values[lhs] = rhs
    return all_to_solve


def _not_balanced_message(
    x: np.ndarray,
    fun: float,
    timeout: float,
    plug_keys: Sequence[str],
    sympy_namespace: Dict[str, IndexedBase],
    bs_diff_max: float,
) -> str:
    """Explain why balancing failed, given the last evaluated plug array and objective."""
    plug_solutions = _x_arr_to_plug_solutions(x, plug_keys, sympy_namespace)
    avg_error = (fun**2 / len(x)) ** 0.5
    return (
        f"final solution {plug_solutions} still could not meet max difference of "
        f"${bs_diff_max:,.0f} within timeout of {timeout}s. "
        f"Average difference was ${avg_error:,.0f}.\nIf the make_forecast or plug "
        f"configuration for any items were changed, ensure that changes in {plug_keys} can flow through "
        f"to Total Assets and Total Liabilities and Equity. For example, if make_forecast=True for Total Debt "
        f"and make_forecast=False for ST Debt, then using LT debt as a plug will not work as ST debt will "
        f"go down when LT debt goes up.\nOtherwise, consider "
        f"passing to .forecast a timeout greater than {timeout}, "
        f"a bs_diff_max at a value greater than {avg_error:,.0f}, or pass "
        f"balance=False to skip balancing entirely."
    )
def _resolve_balance_sheet_check_diff(
    x: np.ndarray,
    eq_arrs: Tuple[np.ndarray, np.ndarray],
    forecast_dates: pd.DatetimeIndex,
    solve_for: Sequence[IndexedBase],
    bs_diff_max: float,
    balance_groups: List[Set[str]],
    res: PlugResult,
):
    """
    Objective for the balancing optimizer: total imbalance across the balance groups.

    For each balance group, the Euclidean norms of the per-period differences between
    every pair of its items are summed. The evaluated ``x`` and the total are stored on
    ``res``.

    :raises BalanceSheetNotBalancedException: if ``res`` has timed out
    :raises BalanceSheetBalancedException: after setting ``res.met_goal``, once every
        group's imbalance is at most the norm of ``bs_diff_max`` in every period
    :return: Sum of the imbalances of all balance groups
    """
    if res.is_timed_out:
        raise BalanceSheetNotBalancedException

    n_periods = len(forecast_dates)
    solution = _eq_arrs_and_x_to_sol_arr(x, eq_arrs)
    # Tolerance: norm of a vector holding bs_diff_max for every period
    tolerance = np.linalg.norm([bs_diff_max] * n_periods)

    group_norms: List[float] = []
    for group in balance_groups:
        item_arrays = _balance_group_to_balance_arrs(
            group, solution, solve_for, n_periods
        )
        group_norm = sum(
            (
                np.linalg.norm(abs(first - second).astype(float))
                for first, second in itertools.combinations(item_arrays, 2)
            ),
            0.0,
        )
        group_norms.append(group_norm)  # type: ignore[arg-type]

    total_norm = sum(group_norms)
    res.res = x
    res.fun = total_norm
    logger.debug(f"{res.time_elapsed:.1f}: x: {x * PLUG_SCALE}, norm: {total_norm}")
    if all(group_norm <= tolerance for group_norm in group_norms):
        res.met_goal = True
        raise BalanceSheetBalancedException(x)
    return total_norm
def _balance_group_to_balance_arrs(
    balance_group: Set[str],
    sol_arr: np.ndarray,
    solve_for: Sequence[IndexedBase],
    num_periods: int,
) -> List[np.ndarray]:
    """
    Collect the solved per-period values of each item in a balance group.

    :param balance_group: Item keys which should balance against each other
    :param sol_arr: Solved values, aligned with ``solve_for``
    :param solve_for: Indexed symbols; the value for index ``k`` is stored at
        position ``k - 1`` and index 0 is ignored
    :param num_periods: Length of each returned array
    :return: One zero-initialised array per entry of ``balance_group``, in its
        iteration order, filled with that item's solved values
    """
    # Item key -> array position; the first occurrence wins, matching list.index
    position_of: Dict[str, int] = {}
    for position, item_key in enumerate(balance_group):
        position_of.setdefault(item_key, position)

    collected: List[np.ndarray] = [
        np.zeros(num_periods) for _ in range(len(balance_group))
    ]
    for value, symbol in zip(sol_arr, solve_for):
        position = position_of.get(str(symbol.base))  # type: ignore[attr-defined]
        if position is None:
            continue
        # Symbol index 0 is the last historical period, so shift down by one
        period_idx = int(symbol.indices[0]) - 1  # type: ignore[attr-defined]
        if period_idx >= 0:
            collected[position][period_idx] = value
    return collected
def _eq_arrs_and_x_to_sol_arr(
    x: np.ndarray, eq_arrs: Tuple[np.ndarray, np.ndarray]
) -> np.ndarray:
    """
    Solve the linear system after substituting the plug values ``x``.

    The trailing ``len(x)`` entries of ``b_arr`` correspond to the plug values (plugs are
    placed last when the system is built). They are overwritten with ``-x * PLUG_SCALE``
    before solving ``A_arr @ sol = b_arr``.

    :param x: Plug values divided by ``PLUG_SCALE``; may be empty
    :param eq_arrs: ``(A_arr, b_arr)``; ``b_arr`` is modified in place
    :return: Solution of the system
    """
    A_arr, b_arr = eq_arrs
    n_plugs = len(x)
    if n_plugs:
        # Explicit start index: b_arr[-0:] would address the whole array, not an empty tail
        b_arr[len(b_arr) - n_plugs :] = -x * PLUG_SCALE
    return np.linalg.solve(A_arr, b_arr)
def _adjust_x0_to_initial_balance_guess(
    x0: np.ndarray,
    plug_keys: Sequence[str],
    eq_arrs: Tuple[np.ndarray, np.ndarray],
    forecast_dates: pd.DatetimeIndex,
    solve_for: Sequence[IndexedBase],
    config: StatementsConfigManager,
    balance_groups: List[Set[str]],
):
    """
    Improve the optimizer's starting point by closing each period's imbalance with a plug.

    For every pair of items in each balance group, the absolute per-period difference is
    added to the plug determining the smaller side (e.g. cash for assets). Periods that
    already balance exactly are left alone.

    :param x0: Plug values divided by ``PLUG_SCALE``, laid out as all periods of the first
        plug in ``plug_keys``, then all periods of the next. Modified in place.
    :param plug_keys: Keys of the plug items
    :param eq_arrs: ``(A_arr, b_arr)`` linear system for the forecast
    :param forecast_dates: Dates of the forecast periods
    :param solve_for: Indexed symbols aligned with the system's solution
    :param config: Statements config, used to find which plug determines each item
    :param balance_groups: Sets of item keys which must balance against each other
    :raises InvalidBalancePlugsException: if a side that needs adjusting has no plug
        among its determinant items
    """
    sol_arr = _eq_arrs_and_x_to_sol_arr(x0, eq_arrs)
    n_periods = len(forecast_dates)
    for balance_group in balance_groups:
        balance_arrs = _balance_group_to_balance_arrs(
            balance_group, sol_arr, solve_for, n_periods
        )
        # Get plug which corresponds to each balance item e.g. find cash for assets.
        # Iterates balance_group in the same order used to build balance_arrs.
        balance_group_plug_keys = [
            _first_plug_key(config, balance_item) for balance_item in balance_group
        ]
        bg_with_arrs = list(zip(balance_group, balance_arrs, balance_group_plug_keys))
        for (bg1, arr1, plug1), (bg2, arr2, plug2) in itertools.combinations(
            bg_with_arrs, 2
        ):
            # e.g. assets - liabilities and equity
            diff = (arr1 - arr2).astype(float)
            for i, d in enumerate(diff):
                # Handle periods one by one
                if d == 0:
                    # Already balanced in this period: no plug is needed on either side
                    continue
                if d > 0:
                    # e.g. assets exceed liabilities and equity, so add to liabilities and equity
                    adjust_side, plug_key = bg2, plug2
                else:
                    # e.g. liabilities and equity exceed assets, so add to assets
                    adjust_side, plug_key = bg1, plug1
                if plug_key is None:
                    raise InvalidBalancePlugsException(
                        _missing_plug_message(adjust_side, config, plug_keys)
                    )
                # x0 holds n_periods entries per plug, plugs in plug_keys order
                arr_idx = plug_keys.index(plug_key) * n_periods + i
                x0[arr_idx] += abs(d) / PLUG_SCALE


def _first_plug_key(config: StatementsConfigManager, item_key: str) -> Optional[str]:
    """Return the first determinant key of ``item_key`` configured as a plug, if any."""
    for key in config.item_determinant_keys(item_key):
        if config.get(key).forecast_config.plug:
            return key  # e.g. cash
    return None


def _missing_plug_message(
    adjust_side: str, config: StatementsConfigManager, plug_keys: Sequence[str]
) -> str:
    """Build the error message for a balance item which no plug can adjust."""
    normally_calculated_but_not_keys: List[str] = [
        item.key
        for item in config.items
        if item.expr_str is not None and item.forecast_config.make_forecast
    ]
    message = (
        f"Trying to balance {adjust_side} but no plug affects it. One of the following "
        f"items must have forecast_config.plug = True so that it can be balanced: "
        f"{config.item_determinant_keys(adjust_side)}. Current plugs: {plug_keys}. "
    )
    if normally_calculated_but_not_keys:
        message += (
            f"If you expected one of the plugs to affect {adjust_side} but it is not listed "
            f"in the possible items, it may be that make_forecast has been set to True for an "
            f"item which would normally be calculated from your plug, but as make_forecast is True "
            f"it forecasts it rather than calculating and it cannot flow through. Possible items "
            f"which are normally calculated but are instead being forecasted due to the config: "
            f"{normally_calculated_but_not_keys}. Either change the plug to be that same item "
            f"which is normally calculated but instead is forecasted, or set make_forecast=False "
            f"for that item."
        )
    return message
class BalanceSheetBalancedException(Exception):
    """
    Raised by the balancing objective to stop the optimizer once the balance sheet balances.

    Used for control flow rather than to signal an error; the plug array that met the
    goal is passed as the exception argument.
    """
def _check_for_invalid_system_of_equations(
    eqs: List[Eq],
    subs_dict: Dict[IndexedBase, float],
    plug_solutions: Dict[IndexedBase, float],
    to_solve_for: List[IndexedBase],
    solve_exprs: List[Expr],
):
    """
    Raise a descriptive error when the numbers of unknowns and expressions differ.

    The message lists items defined by more than one of: the equations, the calculated
    (substituted) values and the plug values.

    NOTE(review): in resolve_balance_sheet both lists come from the keys and values of one
    dict, so this check cannot fail there.

    :raises InvalidForecastEquationException: if ``to_solve_for`` and ``solve_exprs``
        have different lengths
    """
    n_unknowns = len(to_solve_for)
    n_exprs = len(solve_exprs)
    if n_unknowns == n_exprs:
        # Equations seem valid, nothing to report
        return

    # Work out why the system is invalid
    eq_keys = {eq.lhs for eq in eqs}
    subs_keys = {key for key in subs_dict}
    plug_keys = {key for key in plug_solutions}
    overlap_checks = (
        (
            eq_keys & subs_keys,
            "overlap between the equations and the calculated values. ",
        ),
        (
            eq_keys & plug_keys,
            "overlap between the equations and the plug values. ",
        ),
        (
            subs_keys & plug_keys,
            "overlap between the calculated values and the plug values. ",
        ),
    )
    parts = [f"Got {n_unknowns} items to solve for with {n_exprs} equations. "]
    for overlap, description in overlap_checks:
        if overlap:
            parts.append(f"Got {overlap} which {description}")
    raise InvalidForecastEquationException("".join(parts))
def _get_equations_reformed_for_needed_solutions(
    eqs: Sequence[Eq],
    all_hardcoded: Dict[IndexedBase, float],
    config: StatementsConfigManager,
) -> List[Eq]:
    """
    Re-express equations whose left-hand side already has a hard-coded value.

    An item can be calculated (it has an equation) while also being forecasted or plugged
    (it has a hard-coded value). Each such equation is solved for the last right-hand-side
    symbol without a hard-coded value. Other equations are kept unchanged, in order.

    :raises InvalidForecastEquationException: if every right-hand-side symbol of such an
        equation is hard-coded
    """
    reformed: List[Eq] = []
    for eq in eqs:
        if eq.lhs not in all_hardcoded:
            reformed.append(eq)
            continue
        # Got a calculated item which has also been set with make_forecast=True or as plug=True.
        # Look for another variable with make_forecast=False and plug=False to solve for.
        free_symbols = [
            sym for sym in _get_indexed_symbols(eq.rhs) if sym not in all_hardcoded
        ]
        if not free_symbols:
            _raise_all_symbols_hardcoded(eq, config)
        new_lhs = free_symbols[-1]
        reformed.append(Eq(new_lhs, solve(eq, new_lhs)[0]))
    return reformed


def _raise_all_symbols_hardcoded(eq: Eq, config: StatementsConfigManager) -> None:
    """Raise an error suggesting config fixes for an equation with no free symbol."""
    # Use the original unsubstituted equation: variables the user could adjust may
    # have been substituted out of ``eq``
    key = str(eq.lhs.base)
    orig_eq = Eq(eq.lhs, config.expr_for(key))
    fix_strs = []
    for sym in _get_indexed_symbols(orig_eq):
        sym_key = str(sym.base)
        fix_strs.append(
            f'\tstmts.config.update("{sym_key}", ["forecast_config", "make_forecast"], False)\n\t'
            f'stmts.config.update("{sym_key}", ["forecast_config", "plug"], False)'
        )
    possible_fix_str = "\nor,\n".join(fix_strs)
    raise InvalidForecastEquationException(
        f"{eq.lhs} has been set with make_forecast=True or plug=True and yet it is a calculated "
        f"item. Tried to re-express {orig_eq} in terms of another variable which is not forecasted or "
        f"plugged but they all are. Set one of {_get_indexed_symbols(orig_eq)} "
        f"with make_forecast=False and plug=False.\n\nPossible fixes:\n{possible_fix_str}"
    )