#!/usr/bin/env python
import numbers
import warnings
import numpy as np
import supervillain.h5.extendable as extendable
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def resolve_batch_cls(tag):
    r'''
    Look up the element class for a stored :class:`Batch` column.

    Parameters
    ----------
    tag: str or bytes
        Value of ``cls.__batch_tag__`` written to HDF5. An empty string means
        a plain ndarray column (``cls is None``). ``bytes`` (h5py hands back
        fixed-length string attributes as bytes) are decoded as UTF-8 first.

    Returns
    -------
    type or None
        The element class, or ``None`` when ``tag`` is empty.

    Raises
    ------
    ValueError
        If ``tag`` is not registered (or is bytes that are not valid UTF-8;
        :exc:`UnicodeDecodeError` is a :exc:`ValueError` subclass).
    '''
    # Normalize bytes (including numpy.bytes_, a bytes subclass) so the
    # comparisons below work regardless of how the attribute was stored.
    if isinstance(tag, bytes):
        tag = tag.decode('utf-8')
    if not tag:
        return None
    if tag == 'Form':
        # Imported lazily rather than at module top (presumably to avoid an
        # import cycle with supervillain.lattice).
        from supervillain.lattice import Form
        return Form
    raise ValueError(f'Unknown Batch element class tag {tag!r}')
class Batch:
    r'''
    A column of MCMC draws stored as :class:`~supervillain.h5.extendable.array`
    with shape ``(draw, …)``.

    ``extendable`` is for storage only; computation uses ``batch[i]`` (a scalar,
    ndarray slice, or ``cls``-wrapped element such as :class:`~supervillain.lattice.Form`).
    For whole-column ndarray operations use :attr:`array`.

    Parameters
    ----------
    draws_or_data:
        If an ``int``, allocate a new zeroed column of that many draws.
        Otherwise wrap existing batched data (draw axis must be 0).
    cls:
        Optional element class (:class:`~supervillain.lattice.Form`, etc.).
        ``None`` for plain ndarray / scalar columns.
    shape:
        Spatial shape when ``cls`` is ``None`` and ``draws_or_data`` is an ``int``.
        Any iterable of ints (tuple, list, ...) is accepted, as is a single ``int``
        for a one-axis spatial shape; ``()`` gives one scalar per draw.
    dtype:
        Column dtype. When allocating a new column (``draws_or_data`` is an
        ``int``) it defaults to ``float``. When wrapping existing data it
        defaults to the data's own dtype; if given explicitly it must be able
        to hold the data without loss, otherwise a :exc:`TypeError` is raised
        (a lossy cast such as complex→float or float→int is rejected rather
        than silently dropping data). NaNs count as preserved when the target
        dtype can also hold NaN.
    item_kwargs:
        Column-constant keyword arguments passed to ``cls`` on each draw
        (e.g. ``degree``, ``lattice`` for :class:`~supervillain.lattice.Form`).
    '''

    def __init__(self, draws_or_data, *, cls=None, shape=None, dtype=None, **item_kwargs):
        # bool is an Integral subclass; exclude it so Batch(True) is not a 1-draw column.
        if isinstance(draws_or_data, numbers.Integral) and not isinstance(draws_or_data, bool):
            draws = int(draws_or_data)
            if cls is not None:
                spatial = self._normalize_shape(cls.spatial_shape(**item_kwargs))
            elif shape is None:
                raise ValueError('Batch(draws, …) requires shape= when cls is None.')
            else:
                spatial = self._normalize_shape(shape)
            # Allocating a fresh zeroed column; default to float when unspecified.
            arr = np.zeros((draws,) + spatial, dtype=float if dtype is None else dtype)
        else:
            # Wrapping existing data: keep its dtype unless a compatible one is requested.
            arr = np.asarray(draws_or_data) if dtype is None else self._checked_array(draws_or_data, dtype)
        self._data = self._as_extendable(arr)
        self.cls = cls
        r'''The element class used to wrap each draw, or ``None`` for plain arrays.'''
        self.dtype = self._data.dtype
        r'''The column dtype.'''
        self._item_kwargs = item_kwargs
        r'''Column-constant keyword arguments passed to ``cls`` when indexing a single draw.'''

    @staticmethod
    def _normalize_shape(shape):
        r'''
        Return ``shape`` as a tuple so the draw axis can be prepended to it.
        A single integer means one spatial axis; any other iterable is converted
        with :func:`tuple` (element validation is left to NumPy).
        '''
        if isinstance(shape, numbers.Integral) and not isinstance(shape, bool):
            return (shape,)
        return tuple(shape)

    @staticmethod
    def _as_extendable(arr):
        r'''
        Ensure ``arr`` is stored as :class:`~supervillain.h5.extendable.array`.
        '''
        if isinstance(arr, extendable.array):
            return arr
        return extendable.array(arr)

    @classmethod
    def from_data(cls, data, *, dtype=None, **kwargs):
        r'''
        Construct a :class:`Batch` from existing batched data.

        Parameters
        ----------
        data: array_like
            Data whose zeroth axis is the draw index.
        dtype:
            Optional dtype override when wrapping ``data``.
        kwargs:
            Forwarded to :meth:`__init__` (e.g. ``cls``, ``degree``, ``lattice``).

        Returns
        -------
        Batch
            A batch wrapping ``data``.
        '''
        return cls(data, dtype=dtype, **kwargs)

    @property
    def array(self):
        r'''
        The resizable storage column (``extendable.array``, shape ``(draw, …)``).

        Use when you already have a :class:`Batch`. Prefer ``batch[i]`` for one
        draw; use :meth:`as_array` when a value might still be a legacy column.
        '''
        return self._data

    @staticmethod
    def as_array(column):
        r'''
        Unwrap a column for NumPy analysis.

        :class:`Batch` → :attr:`~Batch.array`; anything else passes through
        (legacy ``extendable.array``, plain ndarray). Use at boundaries where
        the static type is unknown — not on attributes known to be ``Batch``.
        '''
        return column.array if isinstance(column, Batch) else column

    @property
    def shape(self):
        r'''
        Returns
        -------
        tuple
            Shape of the underlying storage array, ``(draw, …)``.
        '''
        return self._data.shape

    def __len__(self):
        r'''
        Returns
        -------
        int
            Number of draws (length of axis 0).
        '''
        return len(self._data)

    def __getitem__(self, index):
        r'''
        Parameters
        ----------
        index: int, slice, or tuple
            If an ``int``, return one draw (a scalar, ndarray slice, or ``cls`` instance).
            If a ``slice``, return a new :class:`Batch` sharing metadata.
            Otherwise delegate fancy indexing to the underlying storage array.

        Returns
        -------
        scalar, ndarray, Form, or Batch
            One element, a sub-batch, or an indexed view of the storage array.
        '''
        if isinstance(index, numbers.Integral) and not isinstance(index, bool):
            sliced = self._data[index]
            if self.cls is None:
                return sliced
            return self.cls(sliced, dtype=self.dtype, **self._item_kwargs)
        if type(index) is slice:
            return Batch(
                self._data[index],
                cls=self.cls,
                dtype=self.dtype,
                **self._item_kwargs,
            )
        return self._data[index]

    def __setitem__(self, index, item):
        r'''
        Parameters
        ----------
        index:
            Draw index or indices to overwrite (numpy indexing).
        item:
            Value to store, coerced to :attr:`dtype`.

        Raises
        ------
        TypeError
            If ``item`` cannot be stored as :attr:`dtype` without losing values.
        '''
        self._data[index] = self._coerce_item(item)

    @staticmethod
    def _checked_array(data, dtype):
        r'''
        Cast ``data`` to ``dtype``, raising rather than silently dropping data. The cast
        is allowed exactly when it preserves every value, so an integer-valued float column
        may be stored as ``int`` (no information lost) while ``2.7 → int`` or a complex
        value with a nonzero imaginary part → ``float`` is rejected. NaN is treated as
        preserved when both the source and target dtypes can hold it (``NaN != NaN`` would
        otherwise make every NaN-bearing cast look lossy); NaN → integer is rejected.
        Shared by construction and per-draw writes so both paths enforce the same contract.
        '''
        arr = np.asarray(data)
        dtype = np.dtype(dtype)
        if arr.dtype == dtype:
            return arr
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # we re-check the cast ourselves below
            out = arr.astype(dtype)
        # Compare complex data part by part so a NaN in one component cannot
        # mask a value dropped from the other (e.g. complex(nan, 1) → float).
        if np.iscomplexobj(arr) or np.iscomplexobj(out):
            pairs = ((np.real(out), np.real(arr)), (np.imag(out), np.imag(arr)))
        else:
            pairs = ((out, arr),)
        lossless = all(
            np.array_equal(
                got,
                want,
                # NaN-aware only when both sides can represent NaN.
                equal_nan=np.issubdtype(got.dtype, np.inexact) and np.issubdtype(want.dtype, np.inexact),
            )
            for got, want in pairs
        )
        if not lossless:
            raise TypeError(
                f'Batch cannot store {arr.dtype} data as {dtype} without loss '
                f'(the values do not round-trip); convert it explicitly first '
                f'(e.g. data.real or data.round().astype(...)).'
            )
        return out

    def _coerce_item(self, item):
        r'''
        Coerce a generator return value to the column :attr:`dtype` for storage,
        raising rather than silently dropping data on a lossy cast.
        '''
        return self._checked_array(item, self.dtype)

    def __iter__(self):
        r'''
        Yield one draw per step (same objects as ``batch[i]`` for integer ``i``).
        '''
        for i in range(len(self)):
            yield self[i]

    def extend_h5(self, group):
        r'''
        Append this batch's draws to an on-disk column.

        The ``group`` must be an HDF5 group produced by the ``batch`` write
        strategy (it must contain a resizable ``data`` dataset).

        Parameters
        ----------
        group: h5py.Group
            Group storing the batch column to extend.
        '''
        from supervillain.h5.extendable import strategy as extendable_strategy
        extendable_strategy.extend(group, 'data', self._data)

    def __repr__(self):
        r'''
        Returns
        -------
        str
            A short summary of shape, element class, and dtype.
        '''
        cls_name = self.cls.__name__ if self.cls is not None else 'ndarray'
        return f'Batch(shape={self.shape}, cls={cls_name}, dtype={self.dtype})'