"""Module with utility functions."""
import typing as T
import numpy as np
# Public API of this module: the names exported by a star-import
# (``from <module> import *``).
__all__ = [
'truncate',
'check_shape',
'check_ndim',
'check_size',
'rename_keys',
'RangeDict',
'flatten',
'convert_sequences',
]
[docs]
def truncate(
    *arrays: T.Tuple[np.ndarray, ...], axis: int = 0
) -> T.Iterator[T.Tuple[np.ndarray, ...]]:
    """Truncates the arrays of each group to the shortest array of the group.

    Every positional argument is a tuple (or list) of arrays. Within each
    group, all arrays are cut along ``axis`` to the smallest size found along
    that axis in the group. The layout of the arrays is preserved, i.e. only
    the size of ``axis`` changes.

    Args:
        *arrays (T.Tuple[np.ndarray, ...]): Groups of arrays to truncate.
        axis (int): Axis of the arrays to truncate. Defaults to 0.

    Raises:
        ValueError: Raised (lazily, while iterating) if a group is neither a
          tuple nor a list, or if a group is empty.

    Yields:
        T.Tuple[np.ndarray, ...]: One tuple per group in which all arrays have
        the same size along ``axis``.

    Example:
        >>> arr1 = [np.ones((1, 1)), np.zeros((2, 1))]
        >>> arr2 = [np.zeros((2, 1)), np.ones((3, 1))]
        >>> [[a.shape for a in group] for group in truncate(*zip(arr1, arr2))]
        [[(1, 1), (1, 1)], [(2, 1), (2, 1)]]
    """
    for i, group in enumerate(arrays):
        # The error message has always promised "tuple or list"; accept both.
        if not isinstance(group, (tuple, list)):
            raise ValueError(
                f'Expected tuple or list but got {type(group)} for array at '
                f'position {i}.'
            )

        # Smallest size along ``axis`` within this group.
        length = min(np.shape(x)[axis] for x in group)

        # Move ``axis`` to the front so a plain ``[:length]`` slice truncates
        # it, then move it back so the caller gets the original axis order.
        # Both moveaxis and basic slicing return views of the input.
        yield tuple(
            np.moveaxis(np.moveaxis(x, axis, 0)[:length], 0, axis)
            for x in group
        )
[docs]
def check_shape(
    *arrays: np.ndarray,
    shape: T.Optional[T.Tuple[int, ...]] = None,
    exclude_axis: T.Optional[int] = None,
):
    """Checks that all arrays have the same (expected) shape.

    If ``shape`` is given, every array must match it. Otherwise the shape of
    the first array is used as the reference. With ``exclude_axis`` one axis
    is removed from every array shape before the comparison. Calling the
    function without any array is a no-op.

    Args:
        *arrays (np.ndarray): Arrays to check.
        shape (T.Optional[T.Tuple[int, ...]], optional): Shape to match. Any
          sequence of ints is accepted. If not defined, then the shape of the
          first array is taken. Defaults to None.
        exclude_axis (T.Optional[int], optional): Excludes an axis from the
          check. ``shape`` must be defined without the axis. Defaults to
          None.

    Raises:
        ValueError: Raised if an array has the wrong shape.

    Example:
        >>> check_shape(np.zeros((1, 2)), np.ones((2, 2)),
        ...             shape=(2,), exclude_axis=0)
    """
    if not arrays:
        # Nothing to validate; also avoids indexing an empty list below.
        return

    shapes = [tuple(np.shape(arr)) for arr in arrays]

    if exclude_axis is not None:
        # np.delete treats the shape tuple as a 1-D array; convert the result
        # back to a tuple of Python ints so it compares equal to ``shape``.
        shapes = [
            tuple(np.delete(dims, exclude_axis).astype('int').tolist())
            for dims in shapes
        ]

    if shape is None:
        expected = shapes[0]
    else:
        # A list never compares equal to a tuple, so normalise the type.
        expected = tuple(shape)

    for i, actual in enumerate(shapes):
        if actual != expected:
            raise ValueError(
                f'Expected shape {expected} but got shape {actual} for '
                f'array at position {i} (exclude_axis: {exclude_axis}).'
            )
[docs]
def check_ndim(
    *arrays: np.ndarray,
    ndim: int,
    strict: bool = True,
):
    """Validates the number of dimensions of every passed array.

    In strict mode each array must have exactly ``ndim`` dimensions. In
    non-strict mode ``ndim`` is only an upper bound.

    Args:
        *arrays (np.ndarray): Arrays to check.
        ndim (int): Number of dimensions.
        strict (bool, optional): If true the number of dimensions must match
          exactly, otherwise it must be at most ``ndim``. Defaults to True.

    Raises:
        ValueError: Raised if an array does not match the ndim requirements.

    Example:
        >>> check_ndim(np.ones((1, 2, 3)), np.ones((3, 2, 1)), ndim=3)
    """
    for position, array in enumerate(arrays):
        actual = array.ndim
        exceeds = actual > ndim
        falls_short = strict and actual < ndim
        if not (exceeds or falls_short):
            continue
        raise ValueError(
            f'Expected array of dimension {ndim}, but got {actual} for '
            f'array at position {position}.'
        )
[docs]
def check_size(
    *arrays: np.ndarray,
    size: int,
    axis: int,
    strict: bool = True,
):
    """Checks the number of elements of every array along ``axis``.

    In strict mode each array must have exactly ``size`` elements along
    ``axis``. In non-strict mode ``size`` is only an upper bound and arrays
    that do not have the axis are skipped. Negative axes count from the last
    axis, as usual for array indexing.

    Args:
        *arrays (np.ndarray): Arrays to check.
        size (int): Number of elements along the axis.
        axis (int): Axis to check.
        strict (bool, optional): If true the check is exact, otherwise at
          most. Defaults to True.

    Raises:
        ValueError: Raised when the axis is missing (strict mode only) or the
          size requirement is not fulfilled.

    Example:
        >>> check_size(np.ones((10, 1)), np.ones((10, 3)), size=10, axis=0)
    """
    for i, arr in enumerate(arrays):
        # Valid axes are -ndim .. ndim - 1. Checking both bounds keeps an
        # out-of-range negative axis from surfacing as an IndexError below.
        if not -arr.ndim <= axis < arr.ndim:
            if strict:
                raise ValueError(
                    f'Array at position {i} is missing axis {axis}.'
                )
            continue

        if (shape := arr.shape[axis]) > size or (strict and shape < size):
            raise ValueError(
                f'Expected axis {axis} of array to be of size {size}, but got '
                f'{shape} for array at position {i}.'
            )
[docs]
def rename_keys(
    mapping: T.Dict[str, T.Any],
    *,
    prefix: T.Optional[str] = None,
    suffix: T.Optional[str] = None,
) -> T.Dict[str, T.Any]:
    """Builds a copy of ``mapping`` whose keys carry a prefix and/or suffix.

    The input mapping is left untouched. Keys are formatted into strings.

    Args:
        mapping (T.Dict): Mapping whose keys are renamed.
        prefix (str, optional): String to prepend. Defaults to None.
        suffix (str, optional): String to append. Defaults to None.

    Returns:
        T.Dict: Returns a new mapping with the renamed keys.

    Example:
        >>> rename_keys({'loss': 0.1}, prefix='val_')
        {'val_loss': 0.1}
    """
    # A missing (or empty) prefix/suffix contributes nothing to the new key.
    head = prefix or ""
    tail = suffix or ""

    renamed: T.Dict[str, T.Any] = {}
    for name, value in mapping.items():
        renamed[f'{head}{name}{tail}'] = value
    return renamed
[docs]
def flatten(
    mapping: T.Mapping[str, T.Any],
    *,
    prefix: str = '',
    sep: str = '.',
    flatten_list: bool = True,
) -> T.Mapping[str, T.Any]:
    """Turns a nested mapping into a flattened mapping.

    Nested mappings (and, optionally, lists) are expanded recursively. The
    key of a nested entry is the path to it, joined with ``sep``; list
    elements use their index as the path component. Empty nested mappings and
    (when flattened) empty lists contribute no entries to the result.

    Args:
        mapping (T.Mapping[str, T.Any]): Mapping to flatten.
        prefix (str): Prefix to prepend to every key.
        sep (str): Separator of flattened keys.
        flatten_list (bool): Whether to flatten lists.

    Returns:
        T.Mapping[str, T.Any]: Returns a (flattened) mapping.

    Example:
        >>> flatten({
        ...     'flat1': 1,
        ...     'dict1': {'c': 1, 'd': 2},
        ...     'nested': {'e': {'c': 1, 'd': 2}, 'd': 2},
        ...     'list1': [1, 2],
        ...     'nested_list': [{'1': 1}]
        ... })
        {
            'flat1': 1,
            'dict1.c': 1,
            'dict1.d': 2,
            'nested.e.c': 1,
            'nested.e.d': 2,
            'nested.d': 2,
            'list1.0': 1,
            'list1.1': 2,
            'nested_list.0.1': 1
        }
    """
    flat: T.Dict[str, T.Any] = {}
    for k, value in mapping.items():
        # Formatting (instead of ``str.join``) keeps non-string keys at
        # nested levels from raising a TypeError.
        key = f'{prefix}{sep}{k}' if prefix else k

        if isinstance(value, T.Mapping):
            flat.update(
                flatten(
                    value, prefix=key, sep=sep, flatten_list=flatten_list
                )
            )
        elif flatten_list and isinstance(value, list):
            # Treat the list as a mapping from index to element and reuse the
            # mapping branch for it.
            indexed = {str(i): item for i, item in enumerate(value)}
            flat.update(
                flatten(
                    indexed, prefix=key, sep=sep, flatten_list=flatten_list
                )
            )
        else:
            flat[key] = value
    return flat
[docs]
def convert_sequences(
    mapping: T.Dict[str, T.Any],
    ignore_sequence_types: T.Tuple[T.Type, ...] = tuple(),
) -> T.Dict[str, T.Any]:
    """Replaces sequences in ``mapping`` by dicts keyed by element index.

    The conversion is applied recursively to nested dicts and to the
    elements of converted sequences. Only ``str`` values and instances of
    ``ignore_sequence_types`` are exempt from the conversion. The input
    mapping is not modified.

    Args:
        mapping (T.Dict[str, T.Any]): Input dict.
        ignore_sequence_types (T.Tuple[T.Type, ...]): Sequence types to leave
          as-is. Defaults to an empty tuple.

    Returns:
        T.Dict[str, T.Any]: Returns the new dict.

    Example:
        >>> mapping = {'a': {'b': [{'c': 0, 'd': {'e': 1}}, 'f']}}
        >>> convert_sequences(mapping)
        {'a': {'b': {'0': {'c': 0, 'd': {'e': 1}}, '1': 'f'}}}
    """

    def _convert(value: T.Any) -> T.Any:
        # Nested dicts are processed recursively.
        if isinstance(value, dict):
            return convert_sequences(value, ignore_sequence_types)

        # Strings are sequences too, but must stay scalar values.
        if isinstance(value, T.Sequence) and not isinstance(
            value, (str, *ignore_sequence_types)
        ):
            by_index = {str(pos): item for pos, item in enumerate(value)}
            return convert_sequences(by_index, ignore_sequence_types)

        return value

    return {name: _convert(value) for name, value in mapping.items()}
[docs]
class RangeDict(dict):
    """Dictionary whose keys are containers (e.g. ``range`` objects).

    An item is looked up by a value that is contained in one of the keys.

    Example:
        >>> grades = RangeDict({range(0, 50): 'fail', range(50, 101): 'pass'})
        >>> grades[73]
        'pass'
    """

    def __getitem__(self, k__: int) -> T.Any:
        """Gets the item of the first key that contains `k__`.

        Keys are tested in the iteration order of the dictionary.

        Args:
            k__ (int): Value to look up in the keys.

        Raises:
            KeyError: Raised if no key contains `k__`.

        Returns:
            T.Any: Returns the item stored under the matching key.
        """
        # Unique marker, so that any real key can be told apart from "no hit".
        not_found = object()
        matching_key = next(
            (candidate for candidate in self if k__ in candidate), not_found
        )
        if matching_key is not_found:
            raise KeyError(k__)
        return super().__getitem__(matching_key)