Source code for pantea.models.nn.initializer
from typing import Tuple
from flax import linen as nn
from pantea.types import Array, Dtype
# Type alias used to annotate the random-key (``rng``) argument passed to
# initializers in this module; it is simply another name for ``Array``.
KeyArray = Array
class UniformInitializer:
    """Uniform weight initializer with explicit lower and upper bounds.

    Wraps ``nn.initializers.uniform`` (which, per the JAX/Flax documentation,
    draws from ``[0, scale)``), configured with
    ``scale = weights_range[1] - weights_range[0]``, and shifts every sample
    by ``weights_range[0]`` so that values are spread over the requested range.

    Attributes are plain instance fields:
      * ``weights_range`` -- the ``(lower, upper)`` pair given at construction.
      * ``initializer``   -- the underlying Flax uniform initializer.
    """

    def __init__(self, weights_range: Tuple[float, float]) -> None:
        """Store the range and build the underlying Flax initializer.

        Args:
            weights_range: ``(lower, upper)`` bounds for the sampled weights.
        """
        self.weights_range = weights_range
        # Width of the interval; used as the scale of the wrapped initializer.
        span = self.weights_range[1] - self.weights_range[0]
        self.initializer = nn.initializers.uniform(span)

    def __call__(self, rng: KeyArray, shape: Tuple[int, ...], dtype: Dtype) -> Array:
        """Sample an array of the given shape and dtype.

        Args:
            rng: Random key forwarded to the wrapped initializer.
            shape: Shape of the array to create.
            dtype: Data type forwarded to the wrapped initializer.

        Returns:
            The wrapped initializer's output shifted by ``weights_range[0]``.
        """
        unshifted = self.initializer(rng, shape, dtype)
        # Move the samples so the interval starts at the lower bound.
        return unshifted + self.weights_range[0]