
A test for using symbolic shapes for minformer. #25317

Open

gnecula wants to merge 1 commit into main from minformer_symbolic

Conversation

@gnecula (Collaborator) commented Dec 6, 2024

The actual minformer code is taken from https://github.com/sholtodouglas/minformer/blob/main/minformer/model_test.py, with as few modifications as possible.

The modifications needed to get the symbolic shape evaluation to work are marked with TODO.
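For background, here is a minimal sketch (not part of this PR) of what symbolic shape evaluation looks like in JAX, using jax.export.symbolic_shape and jax.eval_shape; the ffn function and the dimension names are illustrative only:

```python
# Minimal sketch of symbolic-shape evaluation (illustrative, not from this PR).
import jax
import jax.numpy as jnp
from jax import export

def ffn(x, w1, w2):
  # x: (B, D), w1: (D, F), w2: (F, D) -> (B, D)
  return jnp.dot(jax.nn.relu(jnp.dot(x, w1)), w2)

# Declare symbolic dimension variables in a single scope.
B, D, F = export.symbolic_shape("B, D, F")

# jax.eval_shape traces the function abstractly; the shapes stay symbolic.
out = jax.eval_shape(
    ffn,
    jax.ShapeDtypeStruct((B, D), jnp.float32),
    jax.ShapeDtypeStruct((D, F), jnp.float32),
    jax.ShapeDtypeStruct((F, D), jnp.float32),
)
print(out.shape)  # (B, D)
```

The shape printouts in the test output below have the same flavor, with symbolic dimension variables such as D, N, H, and S flowing through the full minformer model.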

To run the examples:

pytest --verbosity 1 -s tests/minformer_symbolic_shape_test.py

The test output:

tests/minformer_symbolic_shape_test.py::MinformerSymbolicShapesTest::test_inference_causal=False
weights shapes: Weights(layers=[Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,))], embedding=(V, D), gamma_final=(D,))
prefill_cache shapes: KVCache(k=[(1, N, T, H), (1, N, T, H), (1, N, T, H), (1, N, T, H)], v=[(1, N, T, H), (1, N, T, H), (1, N, T, H), (1, N, T, H)], lengths=(1,))
chunk_a shapes: (1, S)
segment_ids shapes: (1, S)
logits shapes: (1, S, V)
PASSED
tests/minformer_symbolic_shape_test.py::MinformerSymbolicShapesTest::test_inference_causal=True
weights shapes: Weights(layers=[Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,))], embedding=(V, D), gamma_final=(D,))
prefill_cache shapes: KVCache(k=[(1, N, T, H), (1, N, T, H), (1, N, T, H), (1, N, T, H)], v=[(1, N, T, H), (1, N, T, H), (1, N, T, H), (1, N, T, H)], lengths=(1,))
chunk_a shapes: (1, S)
segment_ids shapes: (1, S)
logits shapes: (1, S, V)
PASSED
tests/minformer_symbolic_shape_test.py::MinformerSymbolicShapesTest::test_overtrain_and_sample_simple_sequence SKIPPED (Works only on TPU)
tests/minformer_symbolic_shape_test.py::MinformerSymbolicShapesTest::test_training_use_attn_kernel=False
weights shapes: Weights(layers=[Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,)), Layer(q=(D, N, H), k=(D, N, H), v=(D, N, H), proj=(N, H, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(H,), k_gamma=(H,))], embedding=(V, D), gamma_final=(D,))
opt_state shapes: Weights(layers=[Layer(q=((D, N, H), (D, N, H), (D, N, H)), k=((D, N, H), (D, N, H), (D, N, H)), v=((D, N, H), (D, N, H), (D, N, H)), proj=((N, H, D), (N, H, D), (N, H, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((H,), (H,), (H,)), k_gamma=((H,), (H,), (H,))), Layer(q=((D, N, H), (D, N, H), (D, N, H)), k=((D, N, H), (D, N, H), (D, N, H)), v=((D, N, H), (D, N, H), (D, N, H)), proj=((N, H, D), (N, H, D), (N, H, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((H,), (H,), (H,)), k_gamma=((H,), (H,), (H,))), Layer(q=((D, N, H), (D, N, H), (D, N, H)), k=((D, N, H), (D, N, H), (D, N, H)), v=((D, N, H), (D, N, H), (D, N, H)), proj=((N, H, D), (N, H, D), (N, H, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((H,), (H,), (H,)), k_gamma=((H,), (H,), (H,))), Layer(q=((D, N, H), (D, N, H), (D, N, H)), k=((D, N, H), (D, N, H), (D, N, H)), v=((D, N, H), (D, N, H), (D, N, H)), proj=((N, H, D), (N, H, D), (N, H, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((H,), (H,), (H,)), k_gamma=((H,), (H,), (H,)))], embedding=((V, D), (V, D), (V, D)), gamma_final=((D,), (D,), (D,)))
prefill_cache shapes: KVCache(k=[(B, N, T, H), (B, N, T, H), (B, N, T, H), (B, N, T, H)], v=[(B, N, T, H), (B, N, T, H), (B, N, T, H), (B, N, T, H)], lengths=(B,))
loss shapes: ()
cache shapes: KVCache(k=[(1, N, T, H), (1, N, T, H), (1, N, T, H), (1, N, T, H)], v=[(1, N, T, H), (1, N, T, H), (1, N, T, H), (1, N, T, H)], lengths=(1,))
PASSED
tests/minformer_symbolic_shape_test.py::MinformerSymbolicShapesTest::test_training_use_attn_kernel=True
weights shapes: Weights(layers=[Layer(q=(D, N, 8), k=(D, N, 8), v=(D, N, 8), proj=(N, 8, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(8,), k_gamma=(8,)), Layer(q=(D, N, 8), k=(D, N, 8), v=(D, N, 8), proj=(N, 8, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(8,), k_gamma=(8,)), Layer(q=(D, N, 8), k=(D, N, 8), v=(D, N, 8), proj=(N, 8, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(8,), k_gamma=(8,)), Layer(q=(D, N, 8), k=(D, N, 8), v=(D, N, 8), proj=(N, 8, D), w1=(D, D*F), w2=(D*F, D), attn_gamma=(D,), ffn_gamma=(D,), q_gamma=(8,), k_gamma=(8,))], embedding=(V, D), gamma_final=(D,))
opt_state shapes: Weights(layers=[Layer(q=((D, N, 8), (D, N, 8), (D, N, 8)), k=((D, N, 8), (D, N, 8), (D, N, 8)), v=((D, N, 8), (D, N, 8), (D, N, 8)), proj=((N, 8, D), (N, 8, D), (N, 8, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((8,), (8,), (8,)), k_gamma=((8,), (8,), (8,))), Layer(q=((D, N, 8), (D, N, 8), (D, N, 8)), k=((D, N, 8), (D, N, 8), (D, N, 8)), v=((D, N, 8), (D, N, 8), (D, N, 8)), proj=((N, 8, D), (N, 8, D), (N, 8, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((8,), (8,), (8,)), k_gamma=((8,), (8,), (8,))), Layer(q=((D, N, 8), (D, N, 8), (D, N, 8)), k=((D, N, 8), (D, N, 8), (D, N, 8)), v=((D, N, 8), (D, N, 8), (D, N, 8)), proj=((N, 8, D), (N, 8, D), (N, 8, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((8,), (8,), (8,)), k_gamma=((8,), (8,), (8,))), Layer(q=((D, N, 8), (D, N, 8), (D, N, 8)), k=((D, N, 8), (D, N, 8), (D, N, 8)), v=((D, N, 8), (D, N, 8), (D, N, 8)), proj=((N, 8, D), (N, 8, D), (N, 8, D)), w1=((D, D*F), (D, D*F), (D, D*F)), w2=((D*F, D), (D*F, D), (D*F, D)), attn_gamma=((D,), (D,), (D,)), ffn_gamma=((D,), (D,), (D,)), q_gamma=((8,), (8,), (8,)), k_gamma=((8,), (8,), (8,)))], embedding=((V, D), (V, D), (V, D)), gamma_final=((D,), (D,), (D,)))
prefill_cache shapes: KVCache(k=[(B, N, T, 8), (B, N, T, 8), (B, N, T, 8), (B, N, T, 8)], v=[(B, N, T, 8), (B, N, T, 8), (B, N, T, 8), (B, N, T, 8)], lengths=(B,))
loss shapes: ()
cache shapes: KVCache(k=[(1, N, T, 8), (1, N, T, 8), (1, N, T, 8), (1, N, T, 8)], v=[(1, N, T, 8), (1, N, T, 8), (1, N, T, 8), (1, N, T, 8)], lengths=(1,))
PASSED
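One way shape summaries like the ones above could be printed is by mapping a shape accessor over a pytree of abstract values; a minimal, self-contained sketch (the dictionary below is a hypothetical stand-in for the real Weights pytree, and the actual test may print its summaries differently):

```python
import jax
import jax.numpy as jnp
from jax import export

V, D = export.symbolic_shape("V, D")

# A tiny stand-in pytree of abstract arrays (the real test uses the full
# minformer Weights/KVCache pytrees).
abstract_weights = {
    "embedding": jax.ShapeDtypeStruct((V, D), jnp.float32),
    "gamma_final": jax.ShapeDtypeStruct((D,), jnp.float32),
}

# Map a shape accessor over the pytree to get a per-leaf shape summary.
print("weights shapes:", jax.tree.map(lambda x: x.shape, abstract_weights))
# -> {'embedding': (V, D), 'gamma_final': (D,)}
```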

gnecula force-pushed the minformer_symbolic branch 2 times, most recently from ae0c656 to 0c887fb on December 6, 2024
gnecula self-assigned this on Dec 6, 2024
gnecula force-pushed the minformer_symbolic branch 3 times, most recently from 9955068 to a24da84 on December 11, 2024