# Source code for pd_utils.plot

from typing import List, Optional, Dict, Any

import pandas as pd
import matplotlib.pyplot as plt

# `ordinal` (defined below) converts an integer to its English ordinal: 1 -> '1st', 2 -> '2nd', etc.
from pandas.plotting._matplotlib.style import get_standard_colors

def ordinal(n: int) -> str:
    """Return *n* formatted as an English ordinal, e.g. 1 -> '1st', 12 -> '12th', 23 -> '23rd'.

    The suffix is chosen from the magnitude of *n*, so negative numbers read
    naturally as well (-1 -> '-1st').
    """
    magnitude = abs(n)
    # 11th, 12th, 13th (and 111th, 212th, ...) break the usual last-digit rule.
    if magnitude % 100 in (11, 12, 13):
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(magnitude % 10, "th")
    return "%d%s" % (n, suffix)

def _axis_location_labels(cols) -> Dict[Any, str]:
    """Map each column name to a legend label stating which y-axis it is drawn on."""
    location_labels: Dict[Any, str] = {}
    for i, col in enumerate(cols):
        # First series is on the left axis; the rest are numbered right-hand axes
        position = 'left' if i == 0 else f'right, {ordinal(i)}'
        location_labels[col] = f'{col} ({position})'
    return location_labels


def _style_y_axis(ax: plt.Axes, ylabel: Any, color: Any, colored_axes: bool) -> None:
    """Set the y-axis label of *ax* and optionally color the label and ticks with *color*."""
    ax.set_ylabel(ylabel=ylabel)
    if colored_axes:
        ax.yaxis.label.set_color(color)
        ax.tick_params(axis='y', colors=color)


def plot_multi_axis(
    df: pd.DataFrame,
    cols: Optional[List[str]] = None,
    spacing: float = .1,
    colored_axes: bool = True,
    axis_locations_in_legend: bool = True,
    legend_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> plt.Axes:
    """
    Plot multiple series with different y-axes

    Adapted from https://stackoverflow.com/a/50655786

    The first column is drawn on the original (left) y-axis. Every further column is
    drawn on its own axis created with ``ax.twinx()``, whose right spine is pushed
    outward by ``spacing`` for each additional axis. One combined legend for all
    series is drawn on the first axis.

    :param df: Data to be plotted
    :param cols: subset of columns to plot, defaults to all columns of ``df``
    :param spacing: Amount of space between y-axes beyond the two which are on the sides of the box,
        in axes-fraction units
    :param colored_axes: Whether to make axis labels and ticks colored the same as the line on the graph
    :param axis_locations_in_legend: Whether to add to the legend which axis corresponds to which plot
    :param legend_kwargs: Keyword arguments to pass to ax.legend
    :param kwargs: df.plot kwargs. If ``ax`` is given, the first series is drawn on it and the
        extra y-axes are twinned from it; it is not forwarded to the plots on the extra axes.
    :return: The axes holding the first (left-hand) series and the combined legend
    :raises ValueError: if there are no columns to plot
    """
    if cols is None:
        cols = df.columns
    if len(cols) == 0:
        raise ValueError('if cols are passed, must not be an empty list')
    if legend_kwargs is None:
        legend_kwargs = {}

    # ``ax`` may only be used for the first series. Later series go on twinned axes, and
    # forwarding a user-supplied ``ax`` there too would raise
    # "got multiple values for keyword argument 'ax'".
    base_ax = kwargs.pop('ax', None)

    # The original names are used for the y-axis labels; the (possibly renamed) columns
    # provide the legend labels
    col_labels = list(cols)
    if axis_locations_in_legend:
        rename_dict = _axis_location_labels(cols)
        df = df.rename(columns=rename_dict)
        cols = list(rename_dict.values())

    # Get default color style from pandas - can be changed to any other color list
    colors = get_standard_colors(num_colors=len(cols))

    # First axis
    color = colors[0]
    ax = df.loc[:, cols[0]].plot(ax=base_ax, label=cols[0], color=color, **kwargs)
    _style_y_axis(ax, col_labels[0], color, colored_axes)
    lines, labels = ax.get_legend_handles_labels()

    for n in range(1, len(cols)):
        color = colors[n % len(colors)]

        # Multiple y-axes: the first twin sits on the right edge of the box (x=1 in axes
        # coordinates); each further one is moved outward by ``spacing``
        ax_new: plt.Axes = ax.twinx()
        ax_new.spines['right'].set_position(('axes', 1 + spacing * (n - 1)))
        df.loc[:, cols[n]].plot(ax=ax_new, label=cols[n], color=color, **kwargs)
        _style_y_axis(ax_new, col_labels[n], color, colored_axes)

        # Collect handles from every twin so a single legend can describe all series
        new_lines, new_labels = ax_new.get_legend_handles_labels()
        lines += new_lines
        labels += new_labels

    ax.legend(lines, labels, **legend_kwargs)
    return ax