
contrastive


contrastive.py: padded contrastive dataset of paired positive/negative sequences with masks.
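To illustrate what a padded positive/negative pair with masks usually looks like, here is a minimal sketch. It assumes right-padding (and truncation) to block_size with a pad id of 0, and a mask that is 1 for real tokens and 0 for padding. The helpers pad_with_mask and pad_pair, the pad id, and the dictionary keys are hypothetical and are not taken from ContrastivePaddedDataset.

```python
import numpy as np


def pad_with_mask(tokens, block_size, pad_id=0):
    """Hypothetical helper: truncate/right-pad tokens to block_size.

    Returns (padded, mask): mask is 1 where a real token sits, 0 on padding.
    """
    tokens = list(tokens)[:block_size]  # truncate sequences longer than block_size
    padded = np.full(block_size, pad_id, dtype=np.int64)
    mask = np.zeros(block_size, dtype=np.int64)
    padded[: len(tokens)] = tokens
    mask[: len(tokens)] = 1
    return padded, mask


def pad_pair(positive, negative, block_size, pad_id=0):
    """Hypothetical helper: pad one positive/negative pair independently."""
    pos, pos_mask = pad_with_mask(positive, block_size, pad_id)
    neg, neg_mask = pad_with_mask(negative, block_size, pad_id)
    return {
        "positive": pos,
        "positive_mask": pos_mask,
        "negative": neg,
        "negative_mask": neg_mask,
    }
```

For example, pad_pair([5, 6, 7], [8, 9], block_size=5) gives a positive of [5 6 7 0 0] with mask [1 1 1 0 0], and a negative of [8 9 0 0 0] with mask [1 1 0 0 0].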

ContrastivePaddedDataset(spec: ExecutionSpec, block_size: int, name: str, suffix: str = '')

Bases: Dataset

get_batch(batch_size: int, split: str = 'train', deterministic_key: Optional[int] = None) -> dict[str, np.ndarray]

Get a batch from the contrastive padded dataset, including masks.
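The call pattern of get_batch can be sketched as follows, assuming the data for each split is stored as a dictionary of equally sized arrays and that deterministic_key is used as a NumPy seed for reproducible sampling. The function sample_batch, the splits layout, the key names, and sampling with replacement are all hypothetical; the actual method may sample, order, or name its outputs differently.

```python
from typing import Optional

import numpy as np


def sample_batch(
    splits: dict[str, dict[str, np.ndarray]],
    batch_size: int,
    split: str = "train",
    deterministic_key: Optional[int] = None,
) -> dict[str, np.ndarray]:
    """Hypothetical sketch: draw batch_size rows from every array of a split.

    All arrays in a split share the same first dimension (number of examples),
    so one index array is reused for each of them.
    """
    data = splits[split]
    n = next(iter(data.values())).shape[0]
    # Seeded generator when a key is given (reproducible); fresh entropy otherwise.
    rng = np.random.default_rng(deterministic_key)
    # Sample row indices uniformly, with replacement.
    idx = rng.integers(0, n, size=batch_size)
    return {name: arr[idx] for name, arr in data.items()}
```

Reusing a single index array across all entries is what keeps each positive aligned with its negative and with both masks in the returned batch.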