"""
Module containing Quantity, which is used to represent deterministic values
"""
# =============================================================================
# IMPORTS
# =============================================================================
import jax
import jax.numpy as jnp
import numpy as np
import pint
import torch
from pint.compat import (
    eq,
    is_duck_array_type,
    zero_or_nan,
)
from pint.errors import DimensionalityError, OffsetUnitCalculusError

import bellini
from bellini.api import utils
# =============================================================================
# MODULE CLASSES
# =============================================================================
class Quantity(pint.quantity.Quantity):
    """A physical quantity: a numeric magnitude paired with units.

    ``__new__`` converts the input value to a ``jax.numpy`` array when
    ``bellini.infer`` is truthy and to a NumPy array otherwise. Every
    instance also carries a ``name`` used as node metadata in its graph
    (see ``g``).
    """

    # Fallback for instances that pint creates without going through
    # ``Quantity.__new__``; normally every instance sets its own ``name``.
    name = None

    @staticmethod
    def _convert_to_numpy(x):
        """Return ``x`` as a NumPy array (NumPy scalars/arrays pass through).

        Parameters
        ----------
        x : int, float, complex, list, tuple, numpy.generic, numpy.ndarray,
            jax.numpy.ndarray or torch.Tensor
            Value to convert.

        Returns
        -------
        numpy.ndarray or numpy.generic

        Raises
        ------
        ValueError
            If ``x`` has an unsupported type.
        """
        if isinstance(x, (int, float, complex, list, tuple)):
            return np.array(x)
        if isinstance(x, (np.generic, np.ndarray)):
            return x
        if isinstance(x, jnp.ndarray):
            return np.array(x)
        if isinstance(x, torch.Tensor):
            # TODO: do not require torch import ahead of time
            # detach()/cpu() let tensors that require grad or live on an
            # accelerator be converted; plain .numpy() rejects both.
            return x.detach().cpu().numpy()
        raise ValueError(
            f"input of type {type(x).__name__} could not be converted to numpy!"
        )

    @staticmethod
    def _convert_to_jnp(x):
        """Return ``x`` as a ``jax.numpy`` array (jnp arrays pass through).

        Parameters
        ----------
        x : int, float, complex, list, tuple, numpy.generic, numpy.ndarray,
            jax.numpy.ndarray or torch.Tensor
            Value to convert.

        Returns
        -------
        jax.numpy.ndarray

        Raises
        ------
        ValueError
            If ``x`` has an unsupported type.
        """
        if isinstance(x, (int, float, complex, list, tuple)):
            return jnp.array(x)
        if isinstance(x, (np.generic, np.ndarray)):
            return jnp.array(x)
        if isinstance(x, jnp.ndarray):
            return x
        if isinstance(x, torch.Tensor):
            # TODO: do not require torch import ahead of time
            # Go through a host NumPy array so grad-tracking / device tensors
            # convert cleanly.
            return jnp.array(x.detach().cpu().numpy())
        raise ValueError(
            f"input of type {type(x).__name__} could not be converted to jnp!"
        )

    def __new__(cls, value, unit=None, name=None):
        """Create a quantity from a raw (non-Quantity) value.

        Parameters
        ----------
        value : scalar, sequence or array
            Magnitude. Converted to jnp when ``bellini.infer`` is truthy,
            otherwise to NumPy.
        unit : optional
            Units passed through to pint; ``None`` means dimensionless.
        name : str, optional
            Name stored on the instance; defaults to ``repr(cls)``.
        """
        assert not isinstance(value, Quantity), "value cannot be a Quantity"
        if bellini.infer:
            value = cls._convert_to_jnp(value)
        else:
            value = cls._convert_to_numpy(value)
        instance = super().__new__(cls, value, unit)
        # Store the name per instance. Assigning ``cls.name`` would make every
        # Quantity report the most recently constructed name.
        instance.name = repr(cls) if name is None else name
        return instance

    def jnp(self):
        r"""Return a copy of this quantity whose magnitude is a ``jnp.ndarray``.

        Units and ``name`` are preserved.
        """
        value = self._convert_to_jnp(self.magnitude)
        # Call pint's constructor directly: ``Quantity.__new__`` would convert
        # the magnitude back to NumPy whenever ``bellini.infer`` is falsy.
        instance = super().__new__(type(self), value, self.units)
        instance.name = self.name
        return instance

    def unitless(self):
        r"""Return a dimensionless copy with the same magnitude and ``name``."""
        return self.__new__(type(self), self.magnitude, name=self.name)

    def to_units(self, new_units, force=False):
        """Convert this quantity to ``new_units``.

        Parameters
        ----------
        new_units : str or pint units
            Target units.
        force : bool, optional
            When the conversion is dimensionally invalid and ``force`` is
            True, keep the magnitude unchanged and relabel it with
            ``new_units`` instead of raising.

        Returns
        -------
        Quantity
            The converted (or relabelled) quantity, carrying ``self.name``.

        Raises
        ------
        DimensionalityError
            If the units are incompatible and ``force`` is False.
        """
        try:
            converted = self.to(new_units)
        except DimensionalityError:
            if not force:
                print(f"cannot convert {self.units} to {new_units}. if you'd like to assign new units, use force=True")
                raise
            return self.__new__(type(self), self.magnitude, new_units, name=self.name)
        # Keep the name consistent with the forced branch above.
        converted.name = self.name
        return converted

    def _build_graph(self):
        """Build, cache on ``self._g`` and return a one-node networkx MultiDiGraph."""
        import networkx as nx

        g = nx.MultiDiGraph()
        g.add_node(self, ntype='quantity', name=self.name)
        self._g = g
        return g

    @property
    def g(self):
        """Graph for this quantity, built lazily on first access."""
        if not hasattr(self, '_g'):
            self._build_graph()
        return self._g

    @staticmethod
    def _defers_to(x):
        """Return True if arithmetic with ``x`` must be left to ``x`` itself.

        Returning ``NotImplemented`` for bellini Distributions and Groups
        makes Python fall back to the other operand's reflected operator.
        """
        return isinstance(x, (bellini.Distribution, bellini.Group))

    def __add__(self, x):
        if self._defers_to(x):
            return NotImplemented
        return super().__add__(x)

    def __sub__(self, x):
        if self._defers_to(x):
            return NotImplemented
        return super().__sub__(x)

    def __mul__(self, x):
        if self._defers_to(x):
            return NotImplemented
        return super().__mul__(x)

    def __truediv__(self, x):
        if self._defers_to(x):
            return NotImplemented
        return super().__truediv__(x)

    def __pow__(self, x):
        if self._defers_to(x):
            return NotImplemented
        return super().__pow__(x)

    def __hash__(self):
        """Hash on class, base-unit magnitude and base units.

        The magnitude is hashed after conversion to base units, so the hash
        does not depend on which units were used when the converted
        magnitudes are byte-identical. Traced JAX values carry no concrete
        data, so only their shape is hashed.
        """
        base = self.to_base_units()
        magnitude = base.magnitude
        # ``jax.core.Tracer`` is the base of all tracer kinds (jit, grad,
        # vmap), not just ``DynamicJaxprTracer``.
        if isinstance(magnitude, jax.core.Tracer):
            return hash((base.__class__, magnitude.shape, base.units))
        # TODO: faster way to hash large arrays? e.g.
        # str(arr.sum()) + str((arr**2).sum())
        # np.asarray also covers magnitudes that are plain Python scalars.
        return hash((base.__class__, np.asarray(magnitude).tobytes(), base.units))

    def __eq__(self, other):
        """Return whether ``self`` equals ``other`` as a single truth value.

        The nested ``super_eq`` appears to follow pint's element-wise
        ``Quantity.__eq__``. Its result is reduced with ``.all()`` when
        ``utils.is_arr`` reports an array, so array-valued quantities compare
        equal only if every element does.

        Raises
        ------
        OffsetUnitCalculusError
            When comparing an offset-unit quantity to zero/NaN while the
            registry does not auto-convert offset units to base units.
        """
        def super_eq(self, other):
            def bool_result(value):
                # Broadcast a scalar verdict to the comparison's shape when the
                # magnitude is array-like; otherwise return it unchanged.
                nonlocal other
                if not is_duck_array_type(type(self._magnitude)):
                    return value
                if isinstance(other, Quantity):
                    other = other._magnitude
                template, _ = np.broadcast_arrays(self._magnitude, other)
                return np.full_like(template, fill_value=value, dtype=np.bool_)

            # NOTE: ``Quantity`` here is this bellini class, so a plain pint
            # quantity also takes this non-Quantity branch.
            if not isinstance(other, Quantity):
                if zero_or_nan(other, True):
                    # Comparing to zero or NaN (or an array of them): the
                    # magnitude is compared directly, ignoring units.
                    if self._is_multiplicative:
                        return eq(self._magnitude, other, False)
                    # Non-multiplicative (offset) units are compared in base
                    # units, if the registry allows that conversion.
                    if self._REGISTRY.autoconvert_offset_to_baseunit:
                        return eq(self.to_base_units()._magnitude, other, False)
                    raise OffsetUnitCalculusError(self._units)
                if self.dimensionless:
                    return eq(
                        self._convert_magnitude_not_inplace(self.UnitsContainer()),
                        other,
                        False,
                    )
                return bool_result(False)

            if self._units == other._units:
                return eq(self._magnitude, other._magnitude, False)
            try:
                return eq(
                    self._convert_magnitude_not_inplace(other._units),
                    other._magnitude,
                    False,
                )
            except DimensionalityError:
                # Incompatible dimensions are simply unequal.
                return bool_result(False)

        is_eq = super_eq(self, other)
        if utils.is_arr(is_eq):
            return is_eq.all()
        return is_eq