feat: Add Ascend npu support. #576

Open
menogrey wants to merge 10 commits into sgl-project:main from menogrey:add_npu_support

@menogrey menogrey commented Jun 9, 2026

Motivation

This PR introduces Ascend NPU support for SpecForge. Most of the changes replace existing CUDA interfaces with hardware-adaptive options; the PR also includes an Ascend NPU installation guide and an example.

Modifications

  • Existing CUDA interfaces: replace .cuda(), to(device="cuda"), device_count, empty_cache, current_device, set_device, synchronize, etc. with device-agnostic helpers.
  • torch.compile: Ascend NPU currently cannot use the inductor backend, so fall back to the eager backend.
  • distributed: Ascend uses the hccl backend.
  • Ascend Triton log-softmax implementation to save memory: triton-ascend cannot use the CUDA Triton implementation because of a UB overflow issue.
  • Ascend NPU installation guide: add requirements-npu.txt and an "Install on Ascend NPU" section to installation.md.
  • A simple example for llama3.1 online training: --max-length and --sglang-mem-fraction-static are changed because of an OOM issue.
  • Directly drop a conversation if it does not match the parse rule. Fixes "[Bug] Found an error when running the example." #557
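
The backend choices in the list above can be sketched as two small lookup helpers. This is a minimal sketch, not the PR's code: `pick_dist_backend` and `pick_compile_backend` are hypothetical names, and the nccl and gloo choices for CUDA and CPU are the usual PyTorch backends assumed here, not something the PR states.

```python
# Hypothetical helpers illustrating the device-dependent choices listed above.
# The function names are illustrative and are not taken from the PR.


def pick_dist_backend(device_type: str) -> str:
    """Map a device type string to a torch.distributed backend name."""
    backends = {
        "npu": "hccl",   # stated in the PR: Ascend uses hccl
        "cuda": "nccl",  # assumption: the usual backend for CUDA devices
    }
    # Assumption: anything else (e.g. "cpu") falls back to gloo.
    return backends.get(device_type, "gloo")


def pick_compile_backend(device_type: str) -> str:
    """Map a device type string to a torch.compile backend name."""
    # Stated in the PR: inductor is not usable on Ascend NPU, so use eager there.
    return "eager" if device_type == "npu" else "inductor"


if __name__ == "__main__":
    for dev in ("cuda", "npu", "cpu"):
        print(dev, pick_dist_backend(dev), pick_compile_backend(dev))
```

Keeping both decisions keyed on a single device-type string means call sites never need to mention CUDA or NPU directly.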

Related Issues

Fixes #557

Accuracy Test

This PR should not affect behavior on CUDA hardware.

Train with examples/run_llama3.1_8b_eagle3_online_npu.sh
Train loss: [image]

llama3.1-8b:

  • humaneval output throughput: 72.55
  • math500 output throughput: 72.30

llama3.1-8b with eagle3-sharegpt:

  • humaneval output throughput: 115.24 (~1.58x), accept length: 2.60
  • math500 output throughput: 111.29 (~1.53x), accept length: 2.52
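
As a quick check of the speedup factors, the sketch below divides each eagle3 throughput by the baseline throughput for the same benchmark. It assumes that this per-benchmark ratio is how the quoted factors were obtained; the numbers are copied from the results above and nothing else is measured here.

```python
# Throughput figures copied from the results above (units as reported there).
baseline = {"humaneval": 72.55, "math500": 72.30}
with_eagle3 = {"humaneval": 115.24, "math500": 111.29}


def speedup(task: str) -> float:
    """Ratio of eagle3 throughput to baseline throughput for one benchmark."""
    return with_eagle3[task] / baseline[task]


if __name__ == "__main__":
    for task in baseline:
        print(f"{task}: {speedup(task):.3f}x")
```

The ratios come out near 1.588 and 1.539, so the quoted ~1.58x and ~1.53x appear to be truncated, not rounded.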

Benchmark & Profiling

Checklist

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request adds support for Ascend NPU devices across the codebase. It introduces NPU installation documentation, a requirements file, an example training script, and NPU-specific Triton kernels for log-softmax loss. Additionally, it replaces hardcoded CUDA references with dynamic device detection and helper functions. The feedback highlights critical issues where the new device helper functions and distributed initialization will crash on CPU-only environments due to invalid module lookups and potential division-by-zero errors. It also suggests quoting variables in the new shell script to prevent path-splitting issues.

Comment thread specforge/utils.py
Comment on lines +87 to +114
def get_device_module():
    """Return the torch device module for the current device type (cuda, npu, etc.)."""
    return torch.get_device_module(get_device_type())


def empty_cache():
    """Empty the cache for the current accelerator device."""
    get_device_module().empty_cache()


def device_count() -> int:
    """Return the number of available accelerator devices."""
    return get_device_module().device_count()


def current_device() -> int:
    """Return the current accelerator device index."""
    return get_device_module().current_device()


def set_device(device_id: int):
    """Set the current accelerator device."""
    get_device_module().set_device(device_id)


def synchronize():
    """Synchronize the current accelerator device."""
    get_device_module().synchronize()

high

The current implementation of get_device_module and other device helpers will fail on CPU-only environments because torch.get_device_module("cpu") is not valid or does not contain accelerator-specific methods like empty_cache or synchronize. To ensure robustness and prevent crashes when running on CPU, we should gracefully handle the "cpu" device type by returning None or acting as a no-op.

def get_device_module():
    """Return the torch device module for the current device type (cuda, npu, etc.)."""
    device_type = get_device_type()
    if device_type == "cpu":
        return None
    return torch.get_device_module(device_type)


def empty_cache():
    """Empty the cache for the current accelerator device."""
    module = get_device_module()
    if module is not None:
        module.empty_cache()


def device_count() -> int:
    """Return the number of available accelerator devices."""
    module = get_device_module()
    if module is not None:
        return module.device_count()
    return 0


def current_device() -> int:
    """Return the current accelerator device index."""
    module = get_device_module()
    if module is not None:
        return module.current_device()
    return 0


def set_device(device_id: int):
    """Set the current accelerator device."""
    module = get_device_module()
    if module is not None:
        module.set_device(device_id)


def synchronize():
    """Synchronize the current accelerator device."""
    module = get_device_module()
    if module is not None:
        module.synchronize()

Comment thread specforge/distributed.py Outdated
Comment on lines +81 to +83
dist.init_process_group(backend=get_dist_backend(), timeout=timedelta(minutes=timeout))
local_rank = dist.get_rank() % device_count()
set_device(local_rank)

high

If device_count() returns 0 (which is the case on CPU-only environments), dist.get_rank() % device_count() will raise a ZeroDivisionError. We should guard against this by checking if device_count() > 0 before performing the modulo operation.

Suggested change
- dist.init_process_group(backend=get_dist_backend(), timeout=timedelta(minutes=timeout))
- local_rank = dist.get_rank() % device_count()
- set_device(local_rank)
+ dist.init_process_group(backend=get_dist_backend(), timeout=timedelta(minutes=timeout))
+ local_rank = dist.get_rank() % device_count() if device_count() > 0 else 0
+ set_device(local_rank)
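
The guard in this suggestion can be isolated into a pure function, which makes the zero-device case easy to exercise without a distributed setup. This is a minimal sketch; `local_rank_for` is a hypothetical name and is not part of the PR.

```python
def local_rank_for(global_rank: int, num_devices: int) -> int:
    """Return the device index for a rank, or 0 when no accelerator is visible."""
    if num_devices <= 0:
        # CPU-only case: avoid `global_rank % 0`, which raises ZeroDivisionError.
        return 0
    return global_rank % num_devices


if __name__ == "__main__":
    print(local_rank_for(5, 4))  # rank 5 on a 4-device node
    print(local_rank_for(3, 0))  # no accelerators visible
    try:
        3 % 0  # what the unguarded expression does with zero devices
    except ZeroDivisionError as exc:
        print("unguarded:", exc)
```

Passing the device count in once also avoids querying it twice, as the one-line conditional expression does.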

Comment on lines +1 to +30
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
# train eagle3 for llama3.1-8b
NUM_GPUS=${1:-1}
TP_SIZE=${2:-1}
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}

# Currently we only train with --max-length 2048 due to OOM issue on A3(64GB)
torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3.py \
    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
    --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
    --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
    --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt \
    --num-epochs 10 \
    --batch-size 1 \
    --tp-size $TP_SIZE \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template llama3 \
    --cache-dir $ROOT_DIR/cache \
    --attention-backend sdpa \
    --target-model-backend sglang \
    --log-interval 10 \
    --sglang-mem-fraction-static 0.3

medium

Unquoted variables in shell scripts can cause word splitting and globbing issues if paths contain spaces or special characters. It is highly recommended to double-quote all variable expansions like $ROOT_DIR and $BUILD_DATASET_NUM_PROC to ensure robustness.

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"
# train eagle3 for llama3.1-8b
NUM_GPUS=${1:-1}
TP_SIZE=${2:-1}
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}

# Currently we only train with --max-length 2048 due to OOM issue on A3(64GB)
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
    --draft-model-config "$ROOT_DIR/configs/llama3-8B-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/dataset/sharegpt_train.jsonl" \
    --build-dataset-num-proc "$BUILD_DATASET_NUM_PROC" \
    --output-dir "$ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt" \
    --num-epochs 10 \
    --batch-size 1 \
    --tp-size "$TP_SIZE" \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template llama3 \
    --cache-dir "$ROOT_DIR/cache" \
    --attention-backend sdpa \
    --target-model-backend sglang \
    --log-interval 10 \
    --sglang-mem-fraction-static 0.3
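
The word-splitting problem described in this comment can be reproduced without a shell by using Python's `shlex` module, which tokenizes strings with POSIX-like rules. This is only an illustration: the path is hypothetical, and `shlex.split` models splitting on whitespace but not globbing.

```python
import shlex

root_dir = "/tmp/my project"  # hypothetical path containing a space

# Unquoted expansion: the path is split into two separate arguments.
unquoted = shlex.split(f"dirname {root_dir}")

# Quoted expansion: the path stays a single argument.
quoted = shlex.split(f"dirname {shlex.quote(root_dir)}")

if __name__ == "__main__":
    print(unquoted)  # ['dirname', '/tmp/my', 'project']
    print(quoted)    # ['dirname', '/tmp/my project']
```

In the script itself, double-quoting each expansion (for example "$ROOT_DIR") has the same effect as `shlex.quote` does here.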

@jiapingW

Collaborator

Hi, can you help fix the conflicts?

menogrey and others added 8 commits June 12, 2026 10:18
Signed-off-by: menogrey <1299267905@qq.com>
- Add get_device_type() to auto-detect npu/cuda/cpu via runtime check
- Add get_local_device() to return torch.device for current LOCAL_RANK
- Replace hardcoded .cuda() and device='cuda' in train_dflash.py with
  dynamic device selection
- Use .to(device, non_blocking=True) for tensor movement to support
  both CUDA and Ascend NPU without code changes
- Maintain backward compatibility: CUDA remains default when available
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
@menogrey

Author

Updated, please take a look @jiapingW

menogrey added 2 commits June 12, 2026 10:51
Signed-off-by: menogrey <1299267905@qq.com>
Signed-off-by: menogrey <1299267905@qq.com>
@menogrey

Author

Fixed lint CI error. @jiapingW anything else from my side that needs to be done for the PR to be merged?
