state-spaces/mamba#771 GitHub Actions CI using EC2 GPU nodes
This PR adds 4 GitHub Actions workflows that install or test `mamba_ssm` on EC2 GPU nodes (using Open-Athena/ec2-gha):
- `install.yaml`: install `mamba_ssm` on an EC2 GPU instance (default: `g4dn.xlarge`)
- `installs.yaml`: run `install.yaml` on 6 recent versions of Mamba (`2.2.{0,1,2,3post2,4,5}`)
- `test.yaml`: run `mamba_ssm` tests on an EC2 GPU instance (`g5` or `g6` series)
- `tests.yaml`: run `test.yaml` on HEAD, on a `g5.2xlarge` and a `g6.2xlarge`
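`installs.yaml` and `tests.yaml` are fan-outs: one runs `install.yaml` once per version, the other runs `test.yaml` once per instance type. As a hypothetical illustration (not the workflows' actual code), this sketch expands the `2.2.{0,1,2,3post2,4,5}` shorthand into the six versions and serializes both lists as JSON of the kind a GitHub Actions `strategy.matrix` can read via `fromJSON(...)`. `expand_braces` is an illustrative helper that handles only a single brace group.

```python
import json
import re


def expand_braces(pattern: str) -> list[str]:
    """Expand one shell-style brace group: "2.2.{0,1}" -> ["2.2.0", "2.2.1"].

    Illustrative helper: nested or multiple brace groups are not supported.
    """
    m = re.fullmatch(r"(.*)\{([^{}]*)\}(.*)", pattern)
    if m is None:
        return [pattern]  # no brace group: a single literal value
    prefix, body, suffix = m.groups()
    return [f"{prefix}{item}{suffix}" for item in body.split(",")]


# Versions that installs.yaml covers, per the PR description.
versions = expand_braces("2.2.{0,1,2,3post2,4,5}")

# Instance types that tests.yaml covers, per the PR description.
instance_types = ["g5.2xlarge", "g6.2xlarge"]

# JSON strings a workflow could feed to `strategy.matrix` via fromJSON(...).
install_matrix = json.dumps({"version": versions})
test_matrix = json.dumps({"instance_type": instance_types})

if __name__ == "__main__":
    print(install_matrix)
    print(test_matrix)
```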
Both g5.2xlarge (A10G) and g6.2xlarge (L4) runs exhibited some bfloat16 precision failures with the original tolerances.
Resolution: tests now pass with relaxed tolerances:
- `test_selective_state_update_with_batch_indices`: `rtol=0.09, atol=0.096` (was `rtol=0.06, atol=0.06`)
- `test_chunk_state_varlen`: `rtol=0.01, atol=0.006` (was `rtol=0.01, atol=0.003`)
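To make the tolerance change concrete, here is a sketch of the elementwise rule `torch.allclose` documents, `|got - expected| <= atol + rtol * |expected|`, evaluated under the old and new settings for `test_selective_state_update_with_batch_indices`. It assumes the test uses `torch.allclose` (or an equivalent check), and the `(expected, got)` pairs are the rounded values reported in the failures listed below, so borderline results could differ slightly from the real tensors.

```python
def within_tol(got: float, expected: float, rtol: float, atol: float) -> bool:
    """Elementwise rule documented for torch.allclose(input=got, other=expected)."""
    return abs(got - expected) <= atol + rtol * abs(expected)


OLD = {"rtol": 0.06, "atol": 0.06}   # original bfloat16 tolerances
NEW = {"rtol": 0.09, "atol": 0.096}  # relaxed tolerances in this PR

# Rounded (expected, got) pairs from the reported failures.
cases = [
    (0.027, 0.090),    # g5.2xlarge [2048-64-True-itype2]
    (0.318, 0.236),    # g6.2xlarge [2064-32-True-itype2]
    (0.006, -0.089),   # g6.2xlarge [2064-64-True-itype2], near-zero expected
    (-0.176, -0.250),  # g6.2xlarge [4096-64-True-itype2]
]

if __name__ == "__main__":
    for expected, got in cases:
        old_ok = within_tol(got, expected, **OLD)
        new_ok = within_tol(got, expected, **NEW)
        print(f"expected={expected:+.3f} got={got:+.3f} old_ok={old_ok} new_ok={new_ok}")
```

Because the bound grows with `|expected|`, near-zero entries are where `atol` has to do the work: the `rel_diff=1583%` case (expected=0.006, got=-0.089) is an absolute error of 0.095, just under the new `atol=0.096`.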
g5.2xlarge (A10G): 2 failures
- `test_selective_state_update_with_batch_indices[2048-64-True-itype2]` (`rtol=0.06, atol=0.06`)
  - 2 out of 32,768 elements (0.006%) exceeded tolerance
  - Worst cases:
    - expected=1.156, got=1.242, abs_diff=0.086, rel_diff=7.4%
    - expected=0.027, got=0.090, abs_diff=0.063, rel_diff=233%
- `test_chunk_state_varlen[128-1-dtype2]` (`rtol=0.01, atol=0.003`)
  - Max diff: 0.00546 (exceeded atol of 0.003)
g6.2xlarge (L4): 3 failures
- `test_selective_state_update_with_batch_indices[2064-32-True-itype2]` (`rtol=0.06, atol=0.06`)
  - 1 out of 33,024 elements (0.003%) exceeded tolerance
  - Worst case:
    - expected=0.318, got=0.236, abs_diff=0.082, rel_diff=25.8%
- `test_selective_state_update_with_batch_indices[2064-64-True-itype2]` (`rtol=0.06, atol=0.06`)
  - 4 out of 33,024 elements (0.012%) exceeded tolerance
  - Worst cases:
    - expected=0.006, got=-0.089, abs_diff=0.095, rel_diff=1583% (near-zero expected)
    - expected=-1.109, got=-1.039, abs_diff=0.070, rel_diff=6.3%
    - expected=0.957, got=0.887, abs_diff=0.070, rel_diff=7.3%
- `test_selective_state_update_with_batch_indices[4096-64-True-itype2]` (`rtol=0.06, atol=0.06`)
  - 1 out of 65,536 elements (0.0015%) exceeded tolerance
  - Worst case:
    - expected=-0.176, got=-0.250, abs_diff=0.074, rel_diff=42.0%
These failures affected only 0.0015-0.012% of tensor elements and are within expected bfloat16 precision limits.
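For a rough sense of what "bfloat16 precision limits" means, this sketch emulates float32 to bfloat16 rounding (round to nearest, ties to even). bfloat16 keeps float32's 8-bit exponent but only 7 explicit mantissa bits, so adjacent bfloat16 values near 1.0 are 2^-7 ≈ 0.0078 apart and a single rounding can move an O(1) value by up to about 0.004; kernels that round intermediates repeatedly can accumulate several such errors. The helper names are illustrative, and NaN/Inf handling is omitted.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    Python floats are first packed as float32, so this is float64 -> float32 -> bfloat16.
    NaN/Inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # Add just under half of the dropped 16-bit range, plus 1 if the kept LSB is odd
    # (ties to even), then clear the low 16 bits.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bfloat16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values at the magnitude of x (normal, nonzero x)."""
    _, exp = math.frexp(x)   # x = m * 2**exp with 0.5 <= |m| < 1
    return 2.0 ** (exp - 8)  # 7 explicit mantissa bits: 2**(floor(log2|x|) - 7)


if __name__ == "__main__":
    for x in (1.0, 1.156, 0.318, 0.027):
        print(x, to_bfloat16(x), bfloat16_spacing(x))
```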
`pip install mamba_ssm==2.2.5` (sans `--no-build-isolation`) succeeds, but older versions fail (cf. install#13).
I learned that it's important to get pre-built `mamba_ssm` wheels (from GitHub Releases; they're not on PyPI):
- The `pip install` job for 2.2.5 took 3m48s on 8/6, but 52m on 8/8.
- The reason seems to be that Torch 2.8.0 was released on 8/6, and 2.2.5 only has pre-built wheels for Torch 2.4 through 2.7, so the install presumably fell back to building from source.
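If that explanation is right, a pre-flight check could flag the mismatch before a long source build starts. A minimal, hypothetical sketch: it only encodes this PR's observation that 2.2.5 has pre-built wheels for Torch 2.4 through 2.7; it does not query GitHub Releases, and `has_prebuilt_wheel` is not part of `mamba_ssm`.

```python
import re

# Torch (major, minor) versions with pre-built mamba_ssm 2.2.5 wheels, per this PR.
PREBUILT_TORCH_MINORS = {(2, 4), (2, 5), (2, 6), (2, 7)}


def torch_minor(version: str) -> tuple[int, int]:
    """Parse a torch version string such as "2.8.0+cu128" into (2, 8)."""
    m = re.match(r"(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognized torch version: {version!r}")
    return int(m.group(1)), int(m.group(2))


def has_prebuilt_wheel(torch_version: str) -> bool:
    """Hypothetical check: True if a pre-built 2.2.5 wheel should match this torch."""
    return torch_minor(torch_version) in PREBUILT_TORCH_MINORS


if __name__ == "__main__":
    try:
        import torch  # optional; only used to read the locally installed version
        version = str(torch.__version__)
    except ImportError:
        version = "2.8.0"  # fall back to the release mentioned above
    if has_prebuilt_wheel(version):
        print(f"torch {version}: a pre-built wheel should match (fast install)")
    else:
        print(f"torch {version}: expect a slow source build, or pin an older torch")
```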
I originally hit issues `pip install`ing `mamba_ssm` on EC2 GPU nodes, and wanted to understand this comment better:
> Try passing `--no-build-isolation` to pip if installation encounters difficulties either when building from source or installing from PyPi. Common pip complaints that can be resolved in this way include PyTorch versions, but other cases exist as well.
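For context on that advice: by default pip builds from source in an isolated, temporary environment and installs the package's declared build requirements there, so if those include torch, the extension can be compiled against a different torch than the one already installed. `--no-build-isolation` makes pip build in the current environment instead, which means torch (and any other build requirements) must already be present. A small sketch of invoking pip that way from Python; `pip_install_cmd` is a hypothetical helper.

```python
import subprocess
import sys


def pip_install_cmd(requirement: str, no_build_isolation: bool = True) -> list[str]:
    """Build a `pip install` command for the current interpreter (hypothetical helper)."""
    cmd = [sys.executable, "-m", "pip", "install"]
    if no_build_isolation:
        # Build in the current environment, so the already-installed torch is used.
        cmd.append("--no-build-isolation")
    cmd.append(requirement)
    return cmd


if __name__ == "__main__":
    cmd = pip_install_cmd("mamba_ssm==2.2.5")
    print(" ".join(cmd))
    if "--run" in sys.argv[1:]:
        subprocess.run(cmd, check=True)  # actually install only when asked
```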
I made Open-Athena/ec2-gha for easier testing, verification, and MREs, and used it in 2 of the workflows here.
I've set these GitHub Actions variables at the Open-Athena org level (repo-level variables also work):
AWS_REGION=us-east-1
AWS_ROLE=arn:aws:iam::066506852143:role/github-actions-role-1-c9ee23c
CLOUDWATCH_LOGS_GROUP=/aws/ec2/github-runners
EC2_INSTANCE_PROFILE=github-runner-ec2-profile-da09798
EC2_KEY_NAME=gha
EC2_LAUNCH_ROLE=arn:aws:iam::066506852143:role/github-actions-role-1-c9ee23c
EC2_SECURITY_GROUP_ID=sg-0eef00964cb375a64
See also example config scripts.
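A hypothetical sanity check for the variables above, e.g. before a first workflow run: it confirms each name is set and that the role and security-group values have the usual AWS shapes (`arn:aws:iam::<12-digit account>:role/<name>`, `sg-<hex>`). These format checks are generic AWS conventions assumed here, not requirements documented by ec2-gha, and `check_vars` is an illustrative name.

```python
import os
import re

# Variable names set for this PR's workflows.
REQUIRED = [
    "AWS_REGION", "AWS_ROLE", "CLOUDWATCH_LOGS_GROUP", "EC2_INSTANCE_PROFILE",
    "EC2_KEY_NAME", "EC2_LAUNCH_ROLE", "EC2_SECURITY_GROUP_ID",
]
ROLE_ARN = re.compile(r"arn:aws:iam::(\d{12}):role/(.+)")  # generic IAM role ARN shape
SG_ID = re.compile(r"sg-[0-9a-f]+")                        # generic security group ID shape


def check_vars(env: dict[str, str]) -> list[str]:
    """Return a list of problems (empty if everything looks well-formed)."""
    problems = [f"missing {name}" for name in REQUIRED if not env.get(name)]
    for name in ("AWS_ROLE", "EC2_LAUNCH_ROLE"):
        if env.get(name) and not ROLE_ARN.fullmatch(env[name]):
            problems.append(f"{name} is not an IAM role ARN")
    sg = env.get("EC2_SECURITY_GROUP_ID")
    if sg and not SG_ID.fullmatch(sg):
        problems.append("EC2_SECURITY_GROUP_ID is not an sg-... ID")
    return problems


if __name__ == "__main__":
    # Assumes the variables have been exported into the environment.
    print(check_vars(dict(os.environ)) or "all variables look well-formed")
```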