import copy
from typing import List, Sequence, Optional, Tuple
from typing_extensions import TypedDict
from pyexlatex.table.models.panels.grid import GridShape
from pyexlatex.table.models.labels.table import LabelTable
from pyexlatex.table.models.labels.collection import LabelCollection
from pyexlatex.table.models.table.section import TableSection
[docs]def common_column_labels(grid: GridShape, use_object_equality=True, enforce_label_order=True):
axis = 1 # columns
all_column_ints = list(range(grid.shape[1]))
return _selected_common_labels_for_axis(
grid,
selections=all_column_ints,
axis=axis,
use_object_equality=use_object_equality,
enforce_label_order=enforce_label_order
)
[docs]def common_row_labels(grid: GridShape, use_object_equality=True, enforce_label_order=True):
axis = 0 # rows
all_row_ints = list(range(grid.shape[0]))
return _selected_common_labels_for_axis(
grid,
selections=all_row_ints,
axis=axis,
use_object_equality=use_object_equality,
enforce_label_order=enforce_label_order
)
def _selected_common_labels_for_axis(grid: GridShape, selections: Sequence[int]=(0,), axis: int=0, use_object_equality=True,
enforce_label_order=True):
common_label_tables: List[LabelTable] = []
for i in selections:
common_label_tables.append(
_common_labels(
grid,
i,
axis=axis,
use_object_equality=use_object_equality,
enforce_label_order=enforce_label_order
)
)
non_none_tables = [table for table in common_label_tables if table is not None]
if non_none_tables == []:
return None
return non_none_tables
def _common_labels(grid: GridShape, num: int, axis: int=0, use_object_equality=True,
enforce_label_order=True):
subgrid = _get_subgrid(
grid=grid,
num=num,
axis=axis
)
label_attr = _get_label_attr(axis=axis)
label_tables: List[Optional[LabelTable]] = []
label_table: Optional[LabelTable]
for section in subgrid:
if isinstance(section, LabelTable):
# A label table directly in the grid
label_table = section
# If we are extracting row labels and there are row labels attached to DataTables in the grid
if axis == 0 and grid.data_has_row_labels:
# The row label part of a label table would be the first value
collection = LabelCollection([label_table[0][0]])
label_table = LabelTable(label_collections=[collection])
elif axis == 1:
# If we are extracting column labels, the entire label table is what we are extracting
pass
else:
label_table = None
else:
# Got some other section such as DataTable, extract labels from it
label_table = getattr(section, label_attr, None)
label_tables.append(label_table)
# first labels missing, no consolidation to be done, consolidated labels are None
if label_tables[0] is None:
return None
if len(label_tables) == 1:
# Only one table, so it must be common
return label_tables[0]
common_label_table = LabelTable([])
for i, label_collection in enumerate(label_tables[0]):
stored_match = False # only want to add each matched collection once. use boolean to track
for label_table in label_tables[1:]:
# If there is a corresponding label table and it has this index label collection
if label_table is not None and i < len(label_table.label_collections):
match = _compare_label_collections(
label_collection,
label_table[i],
use_object_equality=use_object_equality
)
else:
# No labels, nothing to consolidate
match = False
if match:
if not stored_match:
common_label_table.append(label_collection)
stored_match = True
else:
if enforce_label_order:
break # as soon as one label collection doesn't match, stop consolidating
else:
continue # don't worry about non-match, continue consolidating
if common_label_table.is_empty:
return None
return common_label_table
[docs]class LabelsRemoved(TypedDict):
columns: List[Tuple[int, int]]
rows: List[Tuple[int, int]]
[docs]def remove_label_collections_from_grid(grid: GridShape, column_labels: List[LabelTable] = None,
row_labels: List[LabelTable] = None, use_object_equality=True) -> LabelsRemoved:
column_indices: List[Tuple[int, int]] = []
row_indices: List[Tuple[int, int]] = []
for row_num, row in enumerate(grid):
for col_num, section in enumerate(row):
if column_labels is not None:
for label_table in column_labels:
lt = copy.deepcopy(label_table)
removed = _remove_label_collections(
section,
lt,
axis=1,
use_object_equality=use_object_equality,
inplace=True
)
if removed:
column_indices.append((row_num, col_num))
if row_labels is not None:
for label_table in row_labels:
lt = copy.deepcopy(label_table)
removed = _remove_label_collections(
section,
lt,
axis=0,
use_object_equality=use_object_equality,
inplace=True
)
if removed:
row_indices.append((row_num, col_num))
return LabelsRemoved(columns=column_indices, rows=row_indices)
def _remove_label_collections(section: TableSection, label_table: LabelTable, axis: int=0,
use_object_equality=True, inplace=False) -> Tuple[TableSection, bool]:
label_attr = _get_label_attr(axis=axis)
removed = False
# Handle if passed section is already a label table
if isinstance(section, LabelTable):
# If row
if axis == 0:
# If first row label matches
if section.begins_with(label_table[0][0].value):
# Remove that first label
section._label_collections[0].values.pop(0)
removed = True
# If column
else:
# If entire set of labels matches
if section.matches(label_table):
# Remove all these labels
section.label_collections = []
removed = True
# Handle section not having labels for this axis
if not hasattr(section, label_attr):
return section, removed
# Now has labels for this axis. Create a copy to avoid modifying original
if not inplace:
section = copy.deepcopy(section)
for label_collection in label_table:
section_label_table: LabelTable = copy.deepcopy(getattr(section, label_attr, []))
if section_label_table is not None:
for section_label_collection in section_label_table:
match = section_label_collection.is_subset_of(label_collection)
if match:
section_label_table.remove(section_label_collection)
setattr(section, label_attr, section_label_table)
removed = True
# once all label collections have been removed, remove table
_remove_empty_label_table_from_section(section, label_attr)
# Recreate rows in section
section._recreate_rows_if_created()
return section, removed
def _remove_empty_label_table_from_section(section: TableSection, label_attr: str):
section_label_table = getattr(section, label_attr, False)
if section_label_table and section_label_table.label_collections == []:
setattr(section, label_attr, None)
def _get_label_attr(axis: int=0):
# select rows
if axis == 0:
return '_row_labels'
# select columns
elif axis == 1:
return '_column_labels'
else:
raise ValueError(f'axis must be 0 or 1, got {axis}')
def _get_subgrid(grid: GridShape, num: int, axis: int=0):
# select rows
if axis == 0:
return _grid_if_not(grid[num])
# select columns
elif axis == 1:
return _grid_if_not(grid[:, num])
else:
raise ValueError(f'axis must be 0 or 1, got {axis}')
def _grid_if_not(ambiguous_grid):
if isinstance(ambiguous_grid, GridShape):
return ambiguous_grid
else:
return GridShape([ambiguous_grid])
def _compare_label_collections(collection1: LabelCollection, collection2: LabelCollection, use_object_equality=True):
if use_object_equality:
return collection1 == collection2
else:
return collection1.matches(collection2)
def _add_to_label_table_if_not_in_label_table(label_table: LabelTable, label_collection: LabelCollection,
use_object_equality=True):
"""
Note: inplace
"""
# don't want to keep adding match over and over. need to check if match is already
# stored in the common label table. must check two different ways depending on whether
# we are using object equality or string consolidation
if use_object_equality and label_collection not in label_table:
label_table.append(label_collection)
if (not use_object_equality) and (not label_table.contains(label_collection)):
label_table.append(label_collection)