feat: Add Ascend npu support.#576
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds support for Ascend NPU devices across the codebase. It introduces NPU installation documentation, a requirements file, an example training script, and NPU-specific Triton kernels for log-softmax loss. Additionally, it replaces hardcoded CUDA references with dynamic device detection and helper functions. The feedback highlights critical issues where the new device helper functions and distributed initialization will crash on CPU-only environments due to invalid module lookups and potential division-by-zero errors. It also suggests quoting variables in the new shell script to prevent path-splitting issues.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def get_device_module(): | ||
| """Return the torch device module for the current device type (cuda, npu, etc.).""" | ||
| return torch.get_device_module(get_device_type()) | ||
|
|
||
|
|
||
| def empty_cache(): | ||
| """Empty the cache for the current accelerator device.""" | ||
| get_device_module().empty_cache() | ||
|
|
||
|
|
||
| def device_count() -> int: | ||
| """Return the number of available accelerator devices.""" | ||
| return get_device_module().device_count() | ||
|
|
||
|
|
||
| def current_device() -> int: | ||
| """Return the current accelerator device index.""" | ||
| return get_device_module().current_device() | ||
|
|
||
|
|
||
| def set_device(device_id: int): | ||
| """Set the current accelerator device.""" | ||
| get_device_module().set_device(device_id) | ||
|
|
||
|
|
||
| def synchronize(): | ||
| """Synchronize the current accelerator device.""" | ||
| get_device_module().synchronize() |
There was a problem hiding this comment.
The current implementation of get_device_module and other device helpers will fail on CPU-only environments because torch.get_device_module("cpu") is not valid or does not contain accelerator-specific methods like empty_cache or synchronize. To ensure robustness and prevent crashes when running on CPU, we should gracefully handle the "cpu" device type by returning None or acting as a no-op.
def get_device_module():
"""Return the torch device module for the current device type (cuda, npu, etc.)."""
device_type = get_device_type()
if device_type == "cpu":
return None
return torch.get_device_module(device_type)
def empty_cache():
"""Empty the cache for the current accelerator device."""
module = get_device_module()
if module is not None:
module.empty_cache()
def device_count() -> int:
"""Return the number of available accelerator devices."""
module = get_device_module()
if module is not None:
return module.device_count()
return 0
def current_device() -> int:
"""Return the current accelerator device index."""
module = get_device_module()
if module is not None:
return module.current_device()
return 0
def set_device(device_id: int):
"""Set the current accelerator device."""
module = get_device_module()
if module is not None:
module.set_device(device_id)
def synchronize():
"""Synchronize the current accelerator device."""
module = get_device_module()
if module is not None:
module.synchronize()| dist.init_process_group(backend=get_dist_backend(), timeout=timedelta(minutes=timeout)) | ||
| local_rank = dist.get_rank() % device_count() | ||
| set_device(local_rank) |
There was a problem hiding this comment.
If device_count() returns 0 (which is the case on CPU-only environments), dist.get_rank() % device_count() will raise a ZeroDivisionError. We should guard against this by checking if device_count() > 0 before performing the modulo operation.
| dist.init_process_group(backend=get_dist_backend(), timeout=timedelta(minutes=timeout)) | |
| local_rank = dist.get_rank() % device_count() | |
| set_device(local_rank) | |
| dist.init_process_group(backend=get_dist_backend(), timeout=timedelta(minutes=timeout)) | |
| local_rank = dist.get_rank() % device_count() if device_count() > 0 else 0 | |
| set_device(local_rank) |
| SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) | ||
| ROOT_DIR=$(dirname $SCRIPT_DIR) | ||
|
|
||
| export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels | ||
| # train eagle3 for llama3.1-8b | ||
| NUM_GPUS=${1:-1} | ||
| TP_SIZE=${2:-1} | ||
| BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} | ||
|
|
||
| # Currently we only train with --max-length 2048 due to OOM issue on A3(64GB) | ||
| torchrun \ | ||
| --standalone \ | ||
| --nproc_per_node $NUM_GPUS \ | ||
| $ROOT_DIR/scripts/train_eagle3.py \ | ||
| --target-model-path meta-llama/Llama-3.1-8B-Instruct \ | ||
| --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \ | ||
| --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ | ||
| --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ | ||
| --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt \ | ||
| --num-epochs 10 \ | ||
| --batch-size 1 \ | ||
| --tp-size $TP_SIZE \ | ||
| --learning-rate 1e-4 \ | ||
| --max-length 2048 \ | ||
| --chat-template llama3 \ | ||
| --cache-dir $ROOT_DIR/cache \ | ||
| --attention-backend sdpa \ | ||
| --target-model-backend sglang \ | ||
| --log-interval 10 \ | ||
| --sglang-mem-fraction-static 0.3 |
There was a problem hiding this comment.
Unquoted variables in shell scripts can cause word splitting and globbing issues if paths contain spaces or special characters. It is highly recommended to double-quote all variable expansions like $ROOT_DIR and $BUILD_DATASET_NUM_PROC to ensure robustness.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"
# train eagle3 for llama3.1-8b
NUM_GPUS=${1:-1}
TP_SIZE=${2:-1}
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}
# Currently we only train with --max-length 2048 due to OOM issue on A3(64GB)
torchrun \
--standalone \
--nproc_per_node "$NUM_GPUS" \
"$ROOT_DIR/scripts/train_eagle3.py" \
--target-model-path meta-llama/Llama-3.1-8B-Instruct \
--draft-model-config "$ROOT_DIR/configs/llama3-8B-eagle3.json" \
--train-data-path "$ROOT_DIR/cache/dataset/sharegpt_train.jsonl" \
--build-dataset-num-proc "$BUILD_DATASET_NUM_PROC" \
--output-dir "$ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt" \
--num-epochs 10 \
--batch-size 1 \
--tp-size "$TP_SIZE" \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template llama3 \
--cache-dir "$ROOT_DIR/cache" \
--attention-backend sdpa \
--target-model-backend sglang \
--log-interval 10 \
--sglang-mem-fraction-static 0.3|
Hi, can you help fix the conflicts? |
Signed-off-by: menogrey <1299267905@qq.com>
- Add get_device_type() to auto-detect npu/cuda/cpu via runtime check - Add get_local_device() to return torch.device for current LOCAL_RANK - Replace hardcoded .cuda() and device='cuda' in train_dflash.py with dynamic device selection - Use .to(device, non_blocking=True) for tensor movement to support both CUDA and Ascend NPU without code changes - Maintain backward compatibility: CUDA remains default when available
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
be025c8 to
10339f3
Compare
|
updated, please take a look @jiapingW |
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
|
Fixed lint CI error. @jiapingW anything else from my side that needs to be done for the PR to be merged? |
Motivation
This PR introduce Ascend NPU support for SpecForge, the majority of the changes replace existing CUDA interfaces with hardware-adaptive options, also includes Ascend NPU installation guides and examples.
Modifications
requirements-npu.txtandInstall on Ascend NPUsection oninstallation.md--max-lengthand--sglang-mem-fraction-staticdue to OOM issue.Related Issues
Fixes #557
Accuracy Test
This PR should not introduce any affect on CUDA hardware.
Train with

examples/run_llama3.1_8b_eagle3_online_npu.shTrain loss:
llama3.1-8b:
llama3.1-8b with eagle3-sharegpt:
Benchmark & Profiling
Checklist