54 changes: 54 additions & 0 deletions .github/workflows/install.yaml
@@ -0,0 +1,54 @@
name: "Test pip install"
on:
workflow_dispatch:
inputs:
mamba_version:
description: "Mamba version to test"
required: true
type: string
default: "2.2.5"
python:
description: "Python version to use"
required: false
type: string
default: "3.11.13"
workflow_call:
inputs:
mamba_version:
description: "Mamba version to test"
required: true
type: string
python:
description: "Python version to use"
required: false
type: string
default: "3.11.13"
permissions:
id-token: write
contents: read
jobs:
ec2:
uses: Open-Athena/ec2-gha/.github/workflows/runner.yml@v2
secrets: inherit
with:
ec2_instance_type: g4dn.xlarge
ec2_image_id: ami-0aee7b90d684e107d # Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20250623
instance_name: "$repo/$name==${{ inputs.mamba_version }} (#$run_number)"
install:
name: Test mamba_ssm==${{ inputs.mamba_version }}
needs: ec2
runs-on: ${{ needs.ec2.outputs.id }}
steps:
- name: Setup Python environment
run: |
# Set up environment for GitHub Actions to use conda env
echo "/opt/conda/envs/pytorch/bin" >> $GITHUB_PATH
echo "CONDA_DEFAULT_ENV=pytorch" >> $GITHUB_ENV
- name: Install and test mamba_ssm==${{ inputs.mamba_version }}
run: |
# Install mamba_ssm without build isolation to use existing torch from conda env
# No need to reinstall torch since it's already in the conda environment
pip install -v --no-build-isolation mamba_ssm==${{ inputs.mamba_version }}
- name: Verify mamba_ssm installation
run: |
python -c 'import mamba_ssm; print(f"mamba_ssm {mamba_ssm.__version__} installed successfully")'
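
The verify step above only proves that the import works. A slightly stronger smoke test (a sketch, not part of this PR; it assumes a CUDA device is visible and follows the usage shown in the mamba_ssm README) would run a small forward pass to confirm the compiled kernels execute:

import torch
from mamba_ssm import Mamba

# Tiny forward pass through a single Mamba block; shapes follow the
# mamba_ssm README example.
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim, device="cuda")
model = Mamba(
    d_model=dim,  # model dimension
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print("forward pass OK:", tuple(y.shape))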
31 changes: 31 additions & 0 deletions .github/workflows/installs.yaml
@@ -0,0 +1,31 @@
name: "Test pip install - multiple versions"
on:
workflow_dispatch:
inputs:
python:
description: "Python version to use"
required: false
type: string
default: "3.11.13"
permissions:
id-token: write
contents: read
jobs:
installs:
name: Test mamba_ssm==${{ matrix.mamba_version }}
strategy:
matrix:
include:
# All versions support PyTorch 2.4, use AMI's PyTorch 2.4.1
- { "mamba_version": "2.2.0" }
- { "mamba_version": "2.2.1" }
- { "mamba_version": "2.2.2" }
- { "mamba_version": "2.2.3post2" }
- { "mamba_version": "2.2.4" }
- { "mamba_version": "2.2.5" }
fail-fast: false
uses: ./.github/workflows/install.yaml
secrets: inherit
with:
mamba_version: ${{ matrix.mamba_version }}
python: ${{ inputs.python }}
73 changes: 73 additions & 0 deletions .github/workflows/test.yaml
@@ -0,0 +1,73 @@
name: GPU tests
on:
  workflow_dispatch:
    inputs:
      instance_type:
        description: 'EC2 instance type'
        required: false
        type: choice
        default: 'g6.2xlarge'
        options:
          - g5.xlarge   # 4 vCPUs, 16GB RAM, A10G GPU, ≈$1.11/hr
          - g5.2xlarge  # 8 vCPUs, 32GB RAM, A10G GPU, ≈$1.33/hr
          - g5.4xlarge  # 16 vCPUs, 64GB RAM, A10G GPU, ≈$1.79/hr
          - g6.xlarge   # 4 vCPUs, 16GB RAM, L4 GPU, ≈$0.89/hr
          - g6.2xlarge  # 8 vCPUs, 32GB RAM, L4 GPU, ≈$1.08/hr
          - g6.4xlarge  # 16 vCPUs, 64GB RAM, L4 GPU, ≈$1.46/hr
  workflow_call:
    inputs:
      instance_type:
        description: 'EC2 instance type'
        required: true
        type: string
permissions:
  id-token: write
  contents: read
jobs:
  ec2:
    name: Start EC2 runner
    uses: Open-Athena/ec2-gha/.github/workflows/runner.yml@v2
    with:
      ec2_instance_type: ${{ inputs.instance_type || 'g6.2xlarge' }}
      ec2_image_id: ami-0aee7b90d684e107d  # Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20250623
    secrets:
      GH_SA_TOKEN: ${{ secrets.GH_SA_TOKEN }}
  test:
    name: GPU tests
    needs: ec2
    runs-on: ${{ needs.ec2.outputs.id }}
    steps:
      - uses: actions/checkout@v4
      - name: Setup Python environment
        run: |
          # Use the DLAMI's pre-installed PyTorch conda environment
          echo "/opt/conda/envs/pytorch/bin" >> $GITHUB_PATH
          echo "CONDA_DEFAULT_ENV=pytorch" >> $GITHUB_ENV
      - name: Check GPU
        run: nvidia-smi
      - name: Install mamba-ssm and test dependencies
        run: |
          # Use all available CPUs for compilation (we're only building for 1 GPU arch)
          export MAX_JOBS=$(nproc)

          INSTANCE_TYPE="${{ inputs.instance_type || 'g6.2xlarge' }}"

          # Set CUDA architecture based on GPU type
          # TORCH_CUDA_ARCH_LIST tells PyTorch which specific architecture to compile for
          if [[ "$INSTANCE_TYPE" == g5.* ]]; then
            export TORCH_CUDA_ARCH_LIST="8.6"  # A10G GPU
            export CUDA_VISIBLE_DEVICES=0
            export NVCC_GENCODE="-gencode arch=compute_86,code=sm_86"
          elif [[ "$INSTANCE_TYPE" == g6.* ]]; then
            export TORCH_CUDA_ARCH_LIST="8.9"  # L4 GPU (Ada Lovelace)
            export CUDA_VISIBLE_DEVICES=0
            export NVCC_GENCODE="-gencode arch=compute_89,code=sm_89"
          fi

          echo "Building with MAX_JOBS=$MAX_JOBS for $INSTANCE_TYPE"

          # Install mamba-ssm with causal-conv1d and dev dependencies
          # Note: causal-conv1d will download pre-built wheels when available
          pip install -v --no-build-isolation -e .[causal-conv1d,dev]
      - name: Run tests
        run: pytest -vs --maxfail=10 tests/
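
The if/elif above hardcodes one architecture per instance family. For comparison only (a sketch, not what this workflow does), the same values could be derived at runtime from the visible GPU:

import torch

# Map the visible GPU's compute capability to the build variables the
# workflow sets by hand, e.g. (8, 6) on A10G or (8, 9) on L4.
major, minor = torch.cuda.get_device_capability(0)
print(f"TORCH_CUDA_ARCH_LIST={major}.{minor}")
print(f"NVCC_GENCODE=-gencode arch=compute_{major}{minor},code=sm_{major}{minor}")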
26 changes: 26 additions & 0 deletions .github/workflows/tests.yaml
@@ -0,0 +1,26 @@
name: GPU tests on multiple instance types
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch:

permissions:
  id-token: write
  contents: read

jobs:
  test-g5:
    name: Test on g5.2xlarge (A10G)
    uses: ./.github/workflows/test.yaml
    with:
      instance_type: g5.2xlarge
    secrets: inherit

  test-g6:
    name: Test on g6.2xlarge (L4)
    uses: ./.github/workflows/test.yaml
    with:
      instance_type: g6.2xlarge
    secrets: inherit
48 changes: 31 additions & 17 deletions setup.py
@@ -172,25 +172,39 @@ def append_nvcc_threads(nvcc_extra_args):
             "Note: make sure nvcc has a supported version by running nvcc -V."
         )

-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_53,code=sm_53")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_62,code=sm_62")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_70,code=sm_70")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_72,code=sm_72")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_80,code=sm_80")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_87,code=sm_87")
-
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
-    if bare_metal_version >= Version("12.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_100,code=sm_100")
+    # Check for TORCH_CUDA_ARCH_LIST environment variable (for CI/testing)
+    # Format: "7.5" or "7.5;8.6" or "7.5 8.6"
+    cuda_arch_list = os.getenv("TORCH_CUDA_ARCH_LIST", "").replace(";", " ").split()
+
+    if cuda_arch_list:
+        # Use only the specified architectures
+        print(f"Building for specific CUDA architectures: {cuda_arch_list}")
+        for arch in cuda_arch_list:
+            arch_num = arch.replace(".", "")
+            cc_flag.append("-gencode")
+            cc_flag.append(f"arch=compute_{arch_num},code=sm_{arch_num}")
+    else:
+        # Default: build for all supported architectures
+        print("Building for all supported CUDA architectures (set TORCH_CUDA_ARCH_LIST to override)")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_53,code=sm_53")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_62,code=sm_62")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_70,code=sm_70")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_72,code=sm_72")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_80,code=sm_80")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_87,code=sm_87")
+
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
+        if bare_metal_version >= Version("12.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_100,code=sm_100")


     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
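
To see what the new branch produces, here is the added parsing logic run standalone (the same lines as the hunk above, with an example value that a CI job might set):

import os

# Example only: a CI job would export this before running setup.py.
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.6;8.9"

cc_flag = []
cuda_arch_list = os.getenv("TORCH_CUDA_ARCH_LIST", "").replace(";", " ").split()
for arch in cuda_arch_list:
    arch_num = arch.replace(".", "")  # "8.6" -> "86"
    cc_flag.append("-gencode")
    cc_flag.append(f"arch=compute_{arch_num},code=sm_{arch_num}")

print(cc_flag)
# ['-gencode', 'arch=compute_86,code=sm_86', '-gencode', 'arch=compute_89,code=sm_89']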
4 changes: 1 addition & 3 deletions tests/ops/triton/test_selective_state_update.py
@@ -113,9 +113,7 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
-        rtol, atol = 6e-2, 6e-2
-        if torch.version.hip:
-            atol *= 2
+        rtol, atol = 9e-2, 9.6e-2
     # set seed
     torch.random.manual_seed(0)
     batch_size = 16
2 changes: 2 additions & 0 deletions tests/ops/triton/test_ssd.py
@@ -30,6 +30,8 @@ def detach_clone(*args):
 def test_chunk_state_varlen(chunk_size, ngroups, dtype):
     device = 'cuda'
     rtol, atol = (1e-2, 3e-3)
+    if dtype == torch.bfloat16:
+        rtol, atol = 6e-2, 6e-2
     # set seed
     torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64))
     batch = 300
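
As background for both tolerance changes (a sketch, not part of the PR): PyTorch closeness checks of the assert_close family pass elementwise when |actual - expected| <= atol + rtol * |expected|, and bfloat16 keeps only 8 mantissa bits, so float32-grade tolerances are routinely violated after a round-trip through bf16:

import torch

# Round-trip float32 values through bfloat16 and compare the rounding
# error against the two tolerance regimes used in these tests.
x = torch.randn(1024)
err = (x.to(torch.bfloat16).to(torch.float32) - x).abs()

tight = 1e-3 + 3e-4 * x.abs()  # the float32 tolerances above
loose = 6e-2 + 6e-2 * x.abs()  # the bf16 tolerances above
print("tight bound violated:", bool((err > tight).any()))  # typically True
print("loose bound violated:", bool((err > loose).any()))  # False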