huggingface

Hugging Face compatibility layer. I can't believe this is not butter.

Note that this codepath is not used for Qwen/Llama/Pythia model loads because they use native Linen implementations instead of torchax tracing.

This layer is for best-effort tracing through arbitrary Hugging Face models.

HFCompat

Bases: Module

loss(logits: jax.Array, targets: jax.Array) -> jax.Array

Compute cross-entropy loss given logits and targets.
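
As a rough illustration of what a cross-entropy loss with this signature might compute, here is a minimal JAX sketch. It is not taken from the HFCompat source: the standalone function name `cross_entropy_loss` is hypothetical, and it assumes that `targets` holds integer class ids with the same leading shape as `logits` (minus the final vocabulary axis) and that the loss is averaged over all positions. The actual method may differ on masking, reduction, or label shifting.

```python
import jax
import jax.numpy as jnp


def cross_entropy_loss(logits: jax.Array, targets: jax.Array) -> jax.Array:
    """Mean cross-entropy over all positions (illustrative sketch only).

    Assumed shapes: logits [..., vocab], targets [...] with integer class ids.
    """
    # Normalize logits into log-probabilities along the vocabulary axis.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Gather the log-probability assigned to each target id.
    target_log_probs = jnp.take_along_axis(
        log_probs, targets[..., None], axis=-1
    )[..., 0]
    # Negative log-likelihood, averaged over every position (assumed reduction).
    return -jnp.mean(target_log_probs)


# Example: uniform logits over 4 classes give a loss of log(4).
logits = jnp.zeros((2, 3, 4))
targets = jnp.array([[0, 1, 2], [3, 0, 1]])
loss = cross_entropy_loss(logits, targets)
```

With uniform logits every class has probability 1/4, so the loss equals log(4) regardless of the targets, which makes for a convenient sanity check.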