-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fdd2de2
commit 50a0648
Showing
3 changed files
with
52 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ on: | |
- you-pypi-packaging-1 # Runs only when pushing to main | ||
paths: | ||
- "csrc/sliding_tile_attention/setup.py" | ||
- ".github/workflows/sta-publish.yml" | ||
|
||
jobs: | ||
check-version-change: | ||
|
@@ -29,7 +28,7 @@ jobs: | |
echo "New version: $NEW_VERSION" | ||
# Get previous version from git history | ||
OLD_VERSION=$(git show HEAD~1:setup.py | grep -oP 'VERSION\s*=\s*"\K[^"]+' || echo "0.0.0") | ||
OLD_VERSION=$(git show HEAD~1:./setup.py | grep -oP 'VERSION\s*=\s*"\K[^"]+' || echo "0.0.0") | ||
echo "Old version: $OLD_VERSION" | ||
if [ "$NEW_VERSION" != "$OLD_VERSION" ]; then | ||
|
@@ -40,14 +39,13 @@ jobs: | |
echo "Version did not change" | ||
echo "changed=false" >> $GITHUB_OUTPUT | ||
fi | ||
publish_package: | ||
name: Publish package | ||
needs: check-version-change | ||
# needs: [build_wheels] | ||
if: needs.check-version-change.outputs.version-changed == 'true' | ||
runs-on: ubuntu-22.04 | ||
container: | ||
image: nvidia/cuda:12.4.1-devel-ubuntu22.04 # Use CUDA 12.4 Docker image | ||
permissions: | ||
id-token: write # Needed for OIDC Trusted Publishing | ||
|
||
|
@@ -57,23 +55,63 @@ jobs: | |
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: '3.10' | ||
|
||
- name: Install CUDA 12.4.1 | ||
uses: Jimver/[email protected] | ||
id: cuda-toolkit | ||
with: | ||
cuda: 12.4.1 | ||
linux-local-args: '["--toolkit"]' | ||
method: 'network' | ||
sub-packages: '["nvcc"]' | ||
|
||
- name: Install dependencies | ||
- name: Install dependencies (GCC, Clang, CUDA Paths, Git) | ||
run: | | ||
python -m pip install --upgrade pip ninja packaging wheel twine | ||
# Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) | ||
pip install setuptools==75.8.0 | ||
# We don't want to download anything CUDA-related here | ||
pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
sudo apt update | ||
sudo apt install -y git patchelf gcc-11 g++-11 clang-11 | ||
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ /usr/bin/g++-11 | ||
# Allow Git to Access Safe Directory | ||
git config --global --add safe.directory /__w/FastVideo/FastVideo | ||
# Set CUDA environment variables | ||
export CUDA_HOME=/usr/local/cuda-12.4 | ||
export CUDA_HOME=/usr/local/cuda-12.4.1 | ||
export PATH=${CUDA_HOME}/bin:${PATH} | ||
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH | ||
# Verify installation | ||
gcc --version | ||
g++ --version | ||
clang-11 --version | ||
nvcc --version | ||
- name: Install PyTorch 2.5.1+cu12.4.1 | ||
run: | | ||
pip install --upgrade pip | ||
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error | ||
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable | ||
pip install typing-extensions==4.12.2 | ||
# We want to figure out the CUDA version to download pytorch | ||
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 | ||
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix | ||
export TORCH_CUDA_VERSION=124 | ||
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} | ||
nvcc --version | ||
python --version | ||
python -c "import torch; print('PyTorch:', torch.__version__)" | ||
python -c "import torch; print('CUDA:', torch.version.cuda)" | ||
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" | ||
- name: Build source distribution | ||
run: | | ||
cd csrc/sliding_tile_attention | ||
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 | ||
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 | ||
# However this still fails so I'm using a newer version of setuptools | ||
pip install setuptools | ||
pip install ninja packaging wheel | ||
cd csrc/sliding_tile_attention # Move into the correct folder | ||
git submodule update --init --recursive tk # Ensure ThunderKittens submodule is initialized | ||
python setup.py sdist --dist-dir=dist | ||
- name: Publish release distributions to PyPI | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters