Source code for pantea.datasets.dataset

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterator, Optional, Protocol

from pantea.atoms.structure import Structure
from pantea.datasets.runner import RunnerDataSource
from pantea.logger import logger
from pantea.types import Dtype


class DataSourceInterface(Protocol):
    """
    Structural interface of the data sources consumed by ``Dataset``.

    Any object providing these three methods can serve as a data source;
    inheriting from this class is not required.
    """

    def __len__(self) -> int:
        """Return the number of structures the source provides."""
        ...

    def __getitem__(self, index: int) -> Structure:
        """Return the structure stored at position ``index``."""
        ...

    def read_structures(self) -> Iterator[Structure]:
        """Return an iterator over the structures of the source."""
        ...
@dataclass
class Dataset:
    """
    A container for Structure data with caching support.

    Structures are fetched from ``datasource`` by integer index. When
    ``persist`` is true, every fetched structure is stored in ``cache`` and
    served from there on later requests for the same index.
    """

    # Backend that supplies the structures (sized and indexable).
    datasource: DataSourceInterface
    # Whether structures read through ``__getitem__`` are kept in ``cache``.
    persist: bool
    # Already-loaded structures, keyed by the index they were requested with.
    cache: Dict[int, Structure] = field(default_factory=dict, repr=False)

    @classmethod
    def from_runner(
        cls,
        filename: Path,
        persist: bool = False,
        dtype: Optional[Dtype] = None,
    ) -> Dataset:
        """
        Create a dataset backed by a ``RunnerDataSource``.

        :param filename: path handed to ``RunnerDataSource``
        :param persist: whether loaded structures are cached, defaults to False
        :param dtype: forwarded to ``RunnerDataSource``, defaults to None
        :return: a new dataset wrapping the created data source
        """
        datasource = RunnerDataSource(filename, dtype)
        return cls(datasource, persist)

    def __len__(self) -> int:
        """Return the number of structures reported by the data source."""
        return len(self.datasource)

    def __getitem__(self, index: int) -> Structure:
        """
        Read the desired structure, if possible from cache.

        :param index: position of the structure in the data source
        :return: the cached structure when caching is enabled and the index
            was loaded before, otherwise the structure freshly read from the
            data source (and stored in the cache if ``persist`` is true)
        """
        if self.persist and index in self.cache:
            return self.cache[index]

        logger.debug(f"loading structure({index=})")
        structure = self.datasource[index]
        if self.persist:
            self.cache[index] = structure
        return structure

    def preload(self) -> None:
        """
        Preload (cache) all the dataset structures into the memory.

        This ensures that any structure can be rapidly loaded from memory
        in subsequent operations. Caching is switched on (``persist`` becomes
        True) as a side effect.

        The structures are read through the data source's ``read_structures``
        method when it exists; data sources without that method are read one
        index at a time instead. Errors raised while reading are propagated
        to the caller.
        """
        logger.info("Preloading (caching) all structures")
        self.persist = True

        # Only the *absence* of `read_structures` selects the index-based
        # fallback. Resolving the method up front (instead of catching
        # AttributeError around the whole read) keeps an AttributeError raised
        # inside the reader from being mistaken for a missing method.
        read_structures = getattr(self.datasource, "read_structures", None)
        if read_structures is None:
            for index in range(len(self)):
                self.cache[index] = self.datasource[index]
            return

        # The i-th structure yielded by the reader is cached under index i.
        self.cache.update(enumerate(read_structures()))