Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ export PRETRAIN_GDRCOPY_VERSION=2.4.1
# Version pins for the pretraining environment; sourced by the install scripts.
export PRETRAIN_PYTHON_VERSION=3.10.4
export PRETRAIN_TORCH_VERSION=2.8.0
# Apex is pinned to a specific commit (no release tag).
export PRETRAIN_APEX_COMMIT=e13873debc4699d39c6861074b9a3b2a02327f92
# NOTE(review): the next four lines are pre-/post-change diff residue — each
# variable is assigned twice. In shell the later assignment wins, so the
# effective values are FLASH_ATTENTION=2.8.1 and TRANSFORMER_ENGINE=2.8.0;
# only one pair should survive in the real file.
export PRETRAIN_FLASH_ATTENTION_VERSION=060c9188beec3a8b62b33a3bfa6d5d2d44975fab
export PRETRAIN_TRANSFORMER_ENGINE_VERSION=2.5.0
export PRETRAIN_FLASH_ATTENTION_VERSION=2.8.1
export PRETRAIN_TRANSFORMER_ENGINE_VERSION=2.8.0
export PRETRAIN_NVSHMEM_VERSION=3.4.5
# DeepEP is pinned to a specific commit (no release tag).
export PRETRAIN_DEEPEP_VERSION=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee

Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
# Installs flash attention 3 (flash attention for NVIDIA Hopper architecture).
#
# Required environment: PRETRAIN_FLASH_ATTENTION_VERSION and TARGET_DIR must be
# exported by the caller before sourcing/running this script.

# CAUTION(sosuke):
# Installing flash attention v2 and v3 in the same environment may cause problems when used with Megatron-LM.
# We highly recommend only to use flash attention v3 for Hopper architecture.

echo "Installing Flash Attention ${PRETRAIN_FLASH_ATTENTION_VERSION}"
# Build inside the project virtualenv so the package lands in its site-packages.
source ${TARGET_DIR}/venv/bin/activate

pushd ${TARGET_DIR}/src

# NOTE(review): the lines below interleave the pre- and post-change versions of
# this script from a diff — flash-attention is cloned twice and python_path is
# computed twice. Only one variant should survive in the real file.

# Variant A: clone default branch, then pin via `git checkout` (works when the
# version is a commit hash).
git clone https://github.com/Dao-AILab/flash-attention.git
pushd flash-attention/
git checkout ${PRETRAIN_FLASH_ATTENTION_VERSION}

# Use flash attention 3
pushd hopper/
# Variant B: clone the release tag directly, with submodules (works when the
# version is a release number like 2.8.1, prefixed here with "v").
git clone https://github.com/Dao-AILab/flash-attention.git -b v${PRETRAIN_FLASH_ATTENTION_VERSION} --recursive

pushd flash-attention
# install v2
python setup.py install
pushd hopper
# install v3
python setup.py install
# Locate the venv's site-packages so the v3 Python interface module can be
# placed next to the compiled extension.
python_path=$(python -c "import site; print(site.getsitepackages()[0])")
# NOTE(review): cp assumes ${python_path}/flash_attn_3 already exists; the
# wget variant below creates the directory with `mkdir -p` first — confirm
# which behavior is intended.
cp ./flash_attn_interface.py ${python_path}/flash_attn_3
popd
popd

# Alternate variant: fetch flash_attn_interface.py for the pinned version
# straight from GitHub raw instead of copying it from the working tree.
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/${PRETRAIN_FLASH_ATTENTION_VERSION}/hopper/flash_attn_interface.py

popd # hopper/
popd # flash-attention/
popd # ${TARGET_DIR}/src

# Leave the virtualenv so the calling shell is not left activated.
deactivate
3 changes: 2 additions & 1 deletion pretrain/scripts/v4-8b-phase1/base/params.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,5 @@ ALL_PARAMS+=(
# Reserve 16 SMs for communication overlap during LayerNorm fwd/bwd, per the
# NeMo framework performance-tuning guide linked below.
# NOTE(odashi):
# https://docs.nvidia.com/nemo-framework/user-guide/latest/performance/performance-guide.html#communication-overlaps-and-tuning
export NVTE_FWD_LAYERNORM_SM_MARGIN=16
# NOTE(review): the BWD line appears twice — diff residue (old and new line are
# byte-identical apart from a trailing newline). Harmless, but deduplicate.
export NVTE_BWD_LAYERNORM_SM_MARGIN=16
export NVTE_BWD_LAYERNORM_SM_MARGIN=16