Source code for pantea.utils.batch

import math
from typing import Generator

from pantea.types import Array


def create_batch(
    array: Array,
    batch_size: int,
) -> Generator[Array, None, None]:
    """
    Create batches of the input array.

    The array is sliced along its first axis into consecutive chunks of
    ``batch_size`` items. The last batch is shorter when the array length is
    not an exact multiple of ``batch_size``; an empty array yields no batches.

    :param array: input array
    :type array: Array
    :param batch_size: desired batch size (must be a positive integer)
    :type batch_size: int
    :raises ValueError: if ``batch_size`` is less than 1; since this is a
        generator, the error is raised when iteration starts
    :yield: a batch of input array
    :rtype: Generator[Array, None, None]
    """
    # Reject non-positive sizes explicitly: zero would otherwise fail with an
    # unrelated arithmetic error and a negative size would silently yield
    # nothing, hiding the caller's mistake.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    n_items = len(array)  # type: ignore
    # Stepping the start offset directly avoids float division/ceil when
    # working out how many batches there are.
    for start in range(0, n_items, batch_size):
        # The slice end may run past the array; slicing clips it, which
        # produces the shorter final batch.
        yield array[start : start + batch_size, ...]