Source code for pantea.types

from dataclasses import dataclass

import jax.numpy as jnp
from jax._src.numpy.lax_numpy import _ScalarMeta
from jaxlib.xla_extension import ArrayImpl

Array = ArrayImpl
Dtype = _ScalarMeta
Element = str
Scalar = ArrayImpl


[docs]@dataclass class DataType: """ A configuration for array data type. It is globally used as default dtype for arrays, indices, etc. User can modify any default dtype, for example, via setting the global floating point precision (`FLOATX`) to a single (float32) or double (float64). """ FLOATX: Dtype = jnp.float64 INT: Dtype = jnp.int32 UINT: Dtype = jnp.uint32 INDEX: Dtype = jnp.int32
default_dtype = DataType()