scratchbubbles
¶
Scratchbubbles: like Thoughtbubbles, except that we can fork into any other token.
Scratchbubbles
¶
Bases: Thoughtbubbles
left_attn_average(input_seq_len: int, x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array) -> jax.Array
¶
Averaging in which the rightmost token of each index acts as the query of a small masked attention that covers only its left children.
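A minimal JAX sketch of the averaging described above, written as a hypothetical free function left_attn_average_sketch rather than the actual method. It assumes x is (B, T, D), that token_index maps each of the T forked positions to one of input_seq_len original tokens (with every original token occurring at least once), and, purely as a guess, that cumulative_scores enter the attention as a log-prior. The result holds one averaged vector per original token.

```python
import jax
import jax.numpy as jnp


def left_attn_average_sketch(input_seq_len, x, cumulative_scores, token_index):
    """Hypothetical re-implementation for illustration, not the library's code.

    x: (B, T, D) forked residuals; cumulative_scores: (B, T) positive scores;
    token_index: (B, T) original-token id of each forked position.
    Returns (B, input_seq_len, D): one averaged vector per original token.
    """
    B, T, D = x.shape
    positions = jnp.arange(T)
    ids = jnp.arange(input_seq_len)
    # member[b, i, t] is True when forked position t belongs to original token i.
    member = token_index[:, None, :] == ids[None, :, None]  # (B, L, T)
    # Rightmost forked position of each original token (assumes every id occurs).
    rightmost = jnp.max(jnp.where(member, positions, -1), axis=-1)  # (B, L)
    # The rightmost copy is the query; every other copy of that token lies to its left.
    query = x[jnp.arange(B)[:, None], jnp.clip(rightmost, 0)]  # (B, L, D)
    logits = jnp.einsum("bld,btd->blt", query, x) / jnp.sqrt(D)
    # ASSUMPTION: cumulative scores act as a log-prior on each child.
    logits = logits + jnp.log(cumulative_scores + 1e-9)[:, None, :]
    # Mask out positions that belong to a different original token.
    logits = jnp.where(member, logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)  # (B, L, T)
    return jnp.einsum("blt,btd->bld", weights, x)


# Two original tokens, each forked into two identical copies.
x = jnp.array([[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]])
token_index = jnp.array([[0, 0, 1, 1]])
scores = jnp.ones((1, 4))
out = left_attn_average_sketch(2, x, scores, token_index)  # (1, 2, 2)
```

Because the mask only admits positions sharing a token_index, each output row is a convex combination of that token's own copies; identical copies therefore average back to themselves.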
unembed_forked(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, input_seq_len: int) -> jax.Array
¶
Compute output logits for forked tokens using residual averaging.
Example:

    key = jax.random.PRNGKey(0)
    x = jax.random.uniform(key, (job.args.per_device_batch_size, job.args.block_size * 2, job.model.n_embd))
    cumulative_scores = jax.random.uniform(key, (job.args.per_device_batch_size, job.args.block_size * 2))
    token_index = jnp.arange(job.args.block_size)[None, :].repeat(2, axis=1).repeat(job.args.per_device_batch_size, axis=0)
    input_seq_len = job.args.block_size
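One plausible reading of "residual averaging", sketched under stated assumptions: the forks of each original token are averaged with weights given by cumulative_scores, and the averaged residual is projected through a plain unembedding matrix. The function unembed_forked_sketch and the matrix w_unembed are hypothetical; the real method may weight or unembed differently. The token_index layout mirrors the example inputs above (each original token appears twice in a row).

```python
import jax
import jax.numpy as jnp


def unembed_forked_sketch(x, cumulative_scores, token_index, input_seq_len, w_unembed):
    """Hypothetical sketch: score-weighted average of forks, then unembed.

    x: (B, T, D); cumulative_scores: (B, T); token_index: (B, T);
    w_unembed: (D, V) is an ASSUMED plain unembedding matrix.
    Returns logits of shape (B, input_seq_len, V).
    """
    # onehot[b, t, i] = 1 when forked position t came from original token i.
    onehot = jax.nn.one_hot(token_index, input_seq_len, dtype=x.dtype)  # (B, T, L)
    weights = onehot * cumulative_scores[..., None]  # (B, T, L)
    summed = jnp.einsum("btl,btd->bld", weights, x)  # (B, L, D)
    total = jnp.sum(weights, axis=1)[..., None]  # (B, L, 1)
    averaged = summed / jnp.maximum(total, 1e-9)  # guard against tokens with no forks
    return averaged @ w_unembed  # (B, L, V)


# Same token_index layout as the example: [[0, 0, 1, 1]].
token_index = jnp.arange(2)[None, :].repeat(2, axis=1)
x = jnp.array([[[0.0], [4.0], [2.0], [6.0]]])  # (1, 4, 1)
scores = jnp.array([[1.0, 3.0, 1.0, 1.0]])
w_unembed = jnp.array([[1.0, -1.0]])  # D=1, V=2
logits = unembed_forked_sketch(x, scores, token_index, 2, w_unembed)
```

Here token 0 averages to (0*1 + 4*3) / 4 = 3 and token 1 to (2 + 6) / 2 = 4, so the logits are [[[3, -3], [4, -4]]].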
@staticmethod
plot(intermediates: Any) -> Dict[str, Any]
¶
intermediates -> [figure]
pca(x: jax.Array, n_components: int) -> jax.Array
¶
x: (N, D) -> (N, n_components). Projects the mean-centered x onto its top n_components principal directions, computed via SVD.
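A short sketch of centered-SVD PCA matching the stated shapes; it is a standard formulation, not necessarily the exact code behind this method. Points on the line y = x collapse onto a single component.

```python
import jax.numpy as jnp


def pca(x, n_components):
    """Sketch of centered-SVD PCA: (N, D) -> (N, n_components)."""
    centered = x - jnp.mean(x, axis=0, keepdims=True)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = jnp.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


points = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
proj = pca(points, 1)  # the sign of each component is arbitrary
```

The projections are approximately -sqrt(2), 0, sqrt(2) (possibly with all signs flipped), since SVD fixes each direction only up to sign.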
vectors_to_colors(embeddings: jax.Array) -> jax.Array
¶
embeddings: (layers, seq_len, hidden) -> (layers, seq_len, 3) RGB colors.
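A hypothetical sketch of how such a mapping could work, assuming (not confirmed by the source) that all (layer, position) vectors are reduced to 3 components with a centered-SVD PCA and each component is min-max scaled to [0, 1] to serve as an RGB channel. It requires hidden >= 3 and at least 3 vectors in total.

```python
import jax
import jax.numpy as jnp


def vectors_to_colors_sketch(embeddings):
    """Hypothetical: PCA to 3 dims over all vectors, then per-channel min-max to [0, 1]."""
    layers, seq_len, hidden = embeddings.shape
    flat = embeddings.reshape(layers * seq_len, hidden)
    centered = flat - jnp.mean(flat, axis=0, keepdims=True)
    _, _, vt = jnp.linalg.svd(centered, full_matrices=False)
    rgb = centered @ vt[:3].T  # (layers * seq_len, 3)
    lo, hi = rgb.min(axis=0), rgb.max(axis=0)
    rgb = (rgb - lo) / jnp.maximum(hi - lo, 1e-9)  # scale each channel to [0, 1]
    return rgb.reshape(layers, seq_len, 3)


emb = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))
colors = vectors_to_colors_sketch(emb)  # (2, 5, 3)
```

Fitting one PCA over every layer at once keeps colors comparable across layers: the same direction in hidden space gets the same hue wherever it appears.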