Source code for mlnext.utils

"""Module with utility functions.
"""
import typing as T

import numpy as np

# Names exported by ``from mlnext.utils import *``; every public helper
# defined in this module is listed in definition order.
__all__ = [
    'truncate',
    'check_shape',
    'check_ndim',
    'check_size',
    'rename_keys',
    'flatten',
    'RangeDict',
]


def truncate(
    *arrays: T.Union[T.Tuple[np.ndarray, ...], T.List[np.ndarray]],
    axis: int = 0
) -> T.Iterator[T.Tuple[np.ndarray, ...]]:
    """Truncates the arrays in each group to the minimum length of an array
    in that group.

    Note:
        Truncation is performed on a ``np.moveaxis`` view, hence `axis` is
        located at position 0 of every yielded array. For the default
        ``axis=0`` the layout of the arrays is unchanged.

    Args:
        *arrays (tuple or list of np.ndarray): Groups of arrays. Each group
          is truncated independently of the other groups.
        axis (int): Axis of the arrays to truncate.

    Raises:
        ValueError: Raised while iterating if a group is neither a tuple nor
          a list.

    Yields:
        T.Iterator[T.Tuple[np.ndarray, ...]]: Iterator over tuples of arrays
        where all arrays of a tuple have the same length along `axis`.

    Example:
        >>> first = (np.ones((1, 1)), np.zeros((2, 1)))
        >>> second = [np.zeros((2, 1)), np.ones((3, 1))]
        >>> [[a.shape for a in group]
        ...  for group in mlnext.utils.truncate(first, second)]
        [[(1, 1), (1, 1)], [(2, 1), (2, 1)]]
    """
    for i, array in enumerate(arrays):
        # the error message promises "tuple or list", so accept both
        if not isinstance(array, (tuple, list)):
            raise ValueError(
                f'Expected tuple or list but got {type(array)} for array at '
                f'position {i}.')

        # find minimum length in the group
        length = min(np.shape(x)[axis] for x in array)

        # moveaxis creates a view on the array that moves axis to the front,
        # which allows to truncate any axis with a plain leading slice
        yield tuple(np.moveaxis(x, axis, 0)[:length] for x in array)
def check_shape(
    *arrays: np.ndarray,
    shape: T.Optional[T.Sequence[int]] = None,
    exclude_axis: T.Optional[int] = None
):
    """Checks the shape of one or more arrays.

    If `shape` is not None, all arrays must match `shape`. Otherwise `shape`
    is set to the shape of the first array (after removing `exclude_axis`).
    With `exclude_axis`, an axis can be excluded from the check. If no
    arrays are passed, there is nothing to check and the function returns.

    Args:
        *arrays (np.ndarray): Arrays to check.
        shape (T.Optional[T.Sequence[int]], optional): Shape to match
          (tuple or list). If not defined, then the shape of the first array
          is taken. Defaults to None.
        exclude_axis (T.Optional[int], optional): Excludes an axis from the
          check. Shape must be defined without the axis. Defaults to None.

    Raises:
        ValueError: Raised if an array has the wrong shape.
        IndexError: Raised by ``np.delete`` if `exclude_axis` does not exist
          for one of the arrays.

    Example:
        >>> mlnext.utils.check_shape(np.zeros((1, 2)), np.ones((2, 2)),
        ...                          shape=(2,), exclude_axis=0)
    """
    if not arrays:
        # nothing to validate; avoids an IndexError on `shapes[0]` below
        return

    shapes = [tuple(np.shape(arr)) for arr in arrays]

    if exclude_axis is not None:
        # np.delete supports negative axes; tolist() turns the numpy
        # integers back into plain ints to keep error messages readable
        shapes = [
            tuple(np.delete(shape_, exclude_axis).tolist())
            for shape_ in shapes
        ]

    # normalize to a tuple, otherwise a list would never compare equal
    expected = shapes[0] if shape is None else tuple(shape)

    for i, shape_ in enumerate(shapes):
        if shape_ != expected:
            raise ValueError(
                f'Expected shape {expected} but got shape {shape_} for array '
                f'at position {i} (exclude_axis: {exclude_axis}).')
def check_ndim(
    *arrays: np.ndarray,
    ndim: int,
    strict: bool = True
):
    """Validates the number of dimensions of every passed array.

    In strict mode each array must have exactly `ndim` dimensions. In
    non-strict mode `ndim` is an upper bound, i.e., arrays with fewer
    dimensions are accepted as well.

    Args:
        *arrays (np.ndarray): Arrays to check.
        ndim (int): Number of dimensions.
        strict (bool, optional): If true ndim must match exactly, otherwise
          at most. Defaults to True.

    Raises:
        ValueError: Raised if an array does not match the ndim requirements.

    Example:
        >>> mlnext.utils.check_ndim(np.ones((1, 2, 3)), np.ones((3, 2, 1)),
        ...                         ndim=3)
    """
    for position, candidate in enumerate(arrays):
        dims = candidate.ndim

        # exact match in strict mode, upper bound otherwise
        if strict:
            valid = dims == ndim
        else:
            valid = dims <= ndim

        if valid:
            continue

        raise ValueError(
            f'Expected array of dimension {ndim}, but got {dims} for '
            f'array at position {position}.')
def check_size(
    *arrays: np.ndarray,
    size: int,
    axis: int,
    strict: bool = True
):
    """Checks whether each array has exactly `size` elements along `axis`,
    if strict is false then it must be at most `size` elements. If strict is
    false and the axis is missing, then the array is ignored.

    Args:
        *arrays (np.ndarray): Arrays to check.
        size (int): Number of elements along the axis.
        axis (int): Axis to check. Negative values count from the last axis.
        strict (bool, optional): If true the check is exact, otherwise at
          most. Defaults to True.

    Raises:
        ValueError: Raised when the axis is missing or the size requirement
          is not fulfilled.

    Example:
        >>> mlnext.utils.check_size(np.ones((10, 1)), np.ones((10, 3)),
        ...                         size=10, axis=0)
    """
    for i, arr in enumerate(arrays):
        # valid axes are -ndim, ..., ndim - 1; checking both bounds makes
        # an out-of-range negative axis a "missing axis" instead of letting
        # the shape lookup below fail with an IndexError
        if not -arr.ndim <= axis < arr.ndim:
            if strict:
                raise ValueError(
                    f'Array at position {i} is missing axis {axis}.')
            continue

        length = arr.shape[axis]
        if length > size or (strict and length < size):
            raise ValueError(
                f'Expected axis {axis} of array to be of size {size}, but got '
                f'{length} for array at position {i}.')
def rename_keys(
    mapping: T.Dict[str, T.Any],
    *,
    prefix: T.Optional[str] = None,
    suffix: T.Optional[str] = None
) -> T.Dict[str, T.Any]:
    """Builds a new dictionary whose keys are the keys of `mapping` wrapped
    with `prefix` and/or `suffix`. The passed mapping is not modified.

    Args:
        mapping (T.Dict): Mapping.
        prefix (str, optional): String to prepend. Defaults to None.
        suffix (str, optional): String to append. Defaults to None.

    Returns:
        T.Dict: Returns a new mapping with the renamed keys.

    Example:
        >>> mlnext.utils.rename_keys({'a': 1}, prefix='x_', suffix='_y')
        {'x_a_y': 1}
    """
    # a missing (or empty) prefix/suffix contributes nothing to the key
    head = prefix or ""
    tail = suffix or ""

    renamed: T.Dict[str, T.Any] = {}
    for name, value in mapping.items():
        renamed[f'{head}{name}{tail}'] = value

    return renamed
def flatten(
    mapping: T.Mapping[str, T.Any],
    *,
    prefix: str = '',
    sep: str = '.',
    flatten_list: bool = True
) -> T.Mapping[str, T.Any]:
    """Collapses a nested mapping into a single-level mapping.

    The keys along the path to a value are joined with `sep`. List entries
    are addressed by their index. Empty nested mappings (and empty lists if
    `flatten_list` is true) contribute no entries to the result.

    Args:
        mapping (T.Mapping[str, T.Any]): Mapping to flatten.
        prefix (str): Prefix to prepend to the key.
        sep (str): Separator of flattened keys.
        flatten_list (bool): Whether to flatten lists.

    Returns:
        T.Mapping[str, T.Any]: Returns a (flattened) mapping.

    Example:
        >>> flatten({
        ...     'flat1': 1,
        ...     'dict1': {'c': 1, 'd': 2},
        ...     'nested': {'e': {'c': 1, 'd': 2}, 'd': 2},
        ...     'list1': [1, 2],
        ...     'nested_list': [{'1': 1}]
        ... })
        {
            'flat1': 1,
            'dict1.c': 1,
            'dict1.d': 2,
            'nested.e.c': 1,
            'nested.e.d': 2,
            'nested.d': 2,
            'list1.0': 1,
            'list1.1': 2,
            'nested_list.0.1': 1
        }
    """
    flat: T.Dict[str, T.Any] = {}

    for name, value in mapping.items():
        path = f'{sep}'.join((prefix, name)) if prefix else name

        if isinstance(value, T.Mapping):
            nested = value
        elif flatten_list and isinstance(value, list):
            # treat a list like a mapping from index to entry
            nested = {str(idx): entry for idx, entry in enumerate(value)}
        else:
            # leaf value: store it under the joined path
            flat[path] = value
            continue

        flat.update(flatten(
            nested, prefix=path, sep=sep, flatten_list=flatten_list
        ))

    return flat
class RangeDict(dict):
    """Dictionary whose keys are ranges (or other containers) and whose
    items are looked up by membership in a key.

    Only item access via ``[]`` is customized; all other methods behave as
    in ``dict``.

    Example:
        >>> rd = RangeDict({range(0, 5): 'low', range(5, 10): 'high'})
        >>> rd[7]
        'high'
    """

    def __getitem__(self, k__: int) -> T.Any:
        """Returns the item of the first key (in insertion order) that
        contains `k__`.

        Args:
            k__ (int): Value to look up in the keys.

        Raises:
            KeyError: Raised if no range match was found.

        Returns:
            T.Any: Returns the item.
        """
        for span, item in self.items():
            if k__ in span:
                return item

        raise KeyError(k__)