23 changes: 12 additions & 11 deletions .ci/scripts/test_lora.sh
@@ -41,12 +41,14 @@ HF_ADAPTER_PATH=$(
--files "adapter_config.json" "adapter_model.safetensors"
)

# Set environment variables for OmegaConf interpolation in yaml.
export LORA_ADAPTER_CHECKPOINT="${HF_ADAPTER_PATH}/adapter_model.safetensors"
export LORA_ADAPTER_CONFIG="${HF_ADAPTER_PATH}/adapter_config.json"

### SINGLE LORA PTE ###
# Export LoRA PTE file.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
+base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+export.output_name="qwen_lora_math_full.pte"

# Capture the path of the downloaded qwen artifacts
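
The LORA_ADAPTER_* variables exported above are consumed by the new LoRA yaml through OmegaConf's oc.env resolver, so individual invocations no longer need +base.adapter_* overrides. A minimal sketch of that resolution, assuming the omegaconf package and a hypothetical path:

import os
from omegaconf import OmegaConf

# Hypothetical path; test_lora.sh exports the real HuggingFace download location.
os.environ["LORA_ADAPTER_CHECKPOINT"] = "/tmp/hf/adapter_model.safetensors"

cfg = OmegaConf.create({"adapter_checkpoint": "${oc.env:LORA_ADAPTER_CHECKPOINT}"})
# The interpolation resolves on access.
print(cfg.adapter_checkpoint)  # -> /tmp/hf/adapter_model.safetensors
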
@@ -93,9 +95,7 @@ fi
### PROGRAM DATA SEPARATION ###
# Export LoRA PTE, LoRA PTD, foundation PTD file.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
+base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+export.output_name="qwen_lora_math.pte" \
+export.foundation_weights_file="qwen_foundation.ptd" \
+export.lora_weights_file="qwen_lora_math.ptd"
@@ -108,7 +108,7 @@ cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math.pte --dat
NOW=$(date +"%H:%M:%S")
echo "Finished at ${NOW}"

RESULT=$(cat result.txt)
RESULT=$(cat result2.txt)
if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
echo "Expected result prefix: ${EXPECTED_PREFIX}"
echo "Actual result: ${RESULT}"
@@ -143,18 +143,19 @@ So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
The answer is: 12<|im_end|>"

# Export Quantized PTE, PTD file, no LoRA.
# Override base.lora_config=null to avoid creating a LoRA model
# and loading LoRA weights.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
base.lora_config=null \
+export.output_name="qwen_q.pte" \
+export.foundation_weights_file="qwen_foundation_q.ptd" \
+quantization.qmode="8da4w" \
+quantization.group_size=32

# Export Quantized LoRA PTE, LoRA PTD, foundation PTD file.
$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
--config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
+base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
--config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+export.output_name="qwen_lora_math_q.pte" \
+export.foundation_weights_file="qwen_foundation_lora_q.ptd" \
+export.lora_weights_file="qwen_lora_math_q.ptd" \
49 changes: 20 additions & 29 deletions examples/models/llama/model.py
@@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
checkpoint_path = self.llm_config.base.checkpoint
params_path = self.llm_config.base.params

# Adapter checkpoint and config.
adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
adapter_config_path = self.llm_config.base.adapter_config
assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
adapter_checkpoint_path is not None and adapter_config_path is not None
), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
# LoRA adapter configuration.
lora_config = self.llm_config.base.lora_config

self.use_kv_cache = self.llm_config.model.use_kv_cache
self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
@@ -69,10 +65,18 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
with open(params_path, "r") as f:
params = json.loads(f.read())

# Get adapter checkpoint and config.
# Get adapter checkpoint.
adapter_checkpoint = {}
adapter_config = {}
if adapter_checkpoint_path:
if lora_config:
# Resolve LoRA params from adapter_config JSON if not already set.
if lora_config.adapter_config and lora_config.lora_rank == 0:
with open(lora_config.adapter_config, "r") as f:
cfg = json.load(f)
lora_config.lora_rank = cfg["r"]
lora_config.lora_alpha = cfg["lora_alpha"]
lora_config.target_modules = cfg["target_modules"]

adapter_checkpoint_path = lora_config.adapter_checkpoint
if adapter_checkpoint_path.endswith(".pt"):
adapter_checkpoint = torch.load(
adapter_checkpoint_path, map_location=device, mmap=True
@@ -92,22 +96,6 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
raise ValueError(
f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
)

with open(adapter_config_path, "r") as f:
adapter_config_full = json.loads(f.read())
if (
"r" not in adapter_config_full
or "lora_alpha" not in adapter_config_full
or "target_modules" not in adapter_config_full
):
raise ValueError(
"Adapter config must contain r, lora_alpha, and target_modules."
)
adapter_config = {
"r": adapter_config_full["r"],
"lora_alpha": adapter_config_full["lora_alpha"],
"target_modules": adapter_config_full["target_modules"],
}
checkpoint.update(adapter_checkpoint)

output_prune_map = None
@@ -133,8 +121,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
input_prune_map=input_prune_map,
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
r=lora_config.lora_rank if lora_config else None,
lora_alpha=lora_config.lora_alpha if lora_config else None,
target_modules=lora_config.target_modules if lora_config else None,
**params,
**adapter_config,
)

if model_args.use_scaled_rope:
@@ -356,9 +346,10 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):

embedding_bit_width, embedding_group_size = None, None
if self.llm_config.base.preq_embedding_quantize:
embedding_bit_width, embedding_group_size = (
self.llm_config.base.preq_embedding_quantize.split(",")
)
(
embedding_bit_width,
embedding_group_size,
) = self.llm_config.base.preq_embedding_quantize.split(",")
from .source_transformation.pre_quantization import (
transform_embedding_for_pre_quantization,
)
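
Taken together, the model.py changes in this section do three things: load the adapter weights, fill in any LoRA hyperparameters that were not set explicitly from adapter_config.json, and merge the adapter tensors into the base state dict. A standalone sketch of that flow, assuming the safetensors package and hypothetical paths; the helper name is illustrative and not part of the codebase:

import json

import torch
from safetensors.torch import load_file


def merge_lora_adapter(checkpoint: dict, adapter_checkpoint_path: str, adapter_config_path: str):
    """Merge adapter weights into `checkpoint`; return (rank, alpha, target_modules)."""
    if adapter_checkpoint_path.endswith(".safetensors"):
        # HuggingFace PEFT exports adapters as adapter_model.safetensors.
        adapter_checkpoint = load_file(adapter_checkpoint_path)
    elif adapter_checkpoint_path.endswith(".pt"):
        adapter_checkpoint = torch.load(adapter_checkpoint_path, map_location="cpu", mmap=True)
    else:
        raise ValueError(f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}")

    # PEFT's adapter_config.json stores the LoRA rank under "r".
    with open(adapter_config_path, "r") as f:
        cfg = json.load(f)
    lora_rank, lora_alpha, target_modules = cfg["r"], cfg["lora_alpha"], cfg["target_modules"]

    checkpoint.update(adapter_checkpoint)
    return lora_rank, lora_alpha, target_modules
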
@@ -2,6 +2,9 @@ base:
model_class: "qwen3_0_6b"
params: "examples/models/qwen3/config/0_6b_config.json"
metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}'
lora_config:
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}

model:
use_kv_cache: True
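
The lora_config block above feeds the LoraConfig dataclass added to extension/llm/export/config/llm_config.py (next file). A sketch of how the yaml values land on that schema once the environment variables are resolved, assuming the omegaconf package and dummy values; the real export flow goes through export_llm's own config loading:

import os
from dataclasses import dataclass, field
from typing import List, Optional

from omegaconf import OmegaConf


@dataclass
class LoraConfig:
    # Trimmed local copy of the schema added in llm_config.py below.
    adapter_checkpoint: str
    adapter_config: Optional[str] = None
    lora_rank: int = 0
    lora_alpha: int = 0
    target_modules: List[str] = field(default_factory=list)


# Dummy values; the CI script exports the real paths before calling export_llm.
os.environ["LORA_ADAPTER_CHECKPOINT"] = "/tmp/adapter_model.safetensors"
os.environ["LORA_ADAPTER_CONFIG"] = "/tmp/adapter_config.json"

yaml_block = """
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
"""
lora = OmegaConf.merge(OmegaConf.structured(LoraConfig), OmegaConf.create(yaml_block))
print(OmegaConf.to_container(lora, resolve=True))
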
50 changes: 39 additions & 11 deletions extension/llm/export/config/llm_config.py
@@ -62,6 +62,36 @@ class PreqMode(str, Enum):
preq_8da4w_out_8da8w = "8da4w_output_8da8w"


@dataclass
class LoraConfig:
"""LoRA adapter configuration.

Can be created in two ways:

1. From an adapter_config JSON file:
LoraConfig(
adapter_checkpoint="/path/to/adapter.safetensors",
adapter_config="/path/to/adapter_config.json",
)
Note: the user is responsible for parsing the config and
ensuring it doesn't conflict with any explicit values.

2. With explicit values:
LoraConfig(
adapter_checkpoint="/path/to/adapter.safetensors",
lora_rank=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
)
"""

adapter_checkpoint: str
adapter_config: Optional[str] = None
lora_rank: int = 0
lora_alpha: int = 0
target_modules: List[str] = field(default_factory=list)


@dataclass
class BaseConfig:
"""
@@ -77,11 +107,7 @@ class BaseConfig:
If left empty, the model will either be initialized with random weights
if it is a Llama model or the weights will be downloaded from HuggingFace
if it is a non-Llama model.
adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
the model has trained LoRA adapters. Must provide
adapter_config.json.
adapter_config: Path to the adapter_config.json file from torchtune.
Used if the model has trained LoRA adapters. Must provide adapter.pt.
lora_config: LoRA adapter configuration.
tokenizer_path: Path to the tokenizer file.
metadata: Json string containing metadata information.
e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
@@ -98,8 +124,7 @@ class BaseConfig:
model_class: ModelType = ModelType.llama3
params: Optional[str] = None
checkpoint: Optional[str] = None
adapter_checkpoint: Optional[str] = None
adapter_config: Optional[str] = None
lora_config: Optional[LoraConfig] = None
tokenizer_path: Optional[str] = None
metadata: Optional[str] = None
use_lora: int = 0
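
Per the LoraConfig docstring above, lora_rank == 0 acts as the sentinel meaning "resolve rank, alpha, and target_modules from adapter_config.json later" (model.py performs that resolution). A short usage sketch of the two construction modes with hypothetical paths; the import assumes the code runs from the executorch repo root, matching the -m extension.llm.export.export_llm invocations in the CI script:

from extension.llm.export.config.llm_config import LoraConfig

# Mode 1: defer rank/alpha/target_modules to adapter_config.json (lora_rank stays 0).
from_json = LoraConfig(
    adapter_checkpoint="/tmp/adapter_model.safetensors",
    adapter_config="/tmp/adapter_config.json",
)

# Mode 2: specify everything explicitly; no adapter_config.json needed.
explicit = LoraConfig(
    adapter_checkpoint="/tmp/adapter_model.safetensors",
    lora_rank=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
)
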
@@ -536,10 +561,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
llm_config.base.params = args.params
if hasattr(args, "checkpoint"):
llm_config.base.checkpoint = args.checkpoint
if hasattr(args, "adapter_checkpoint"):
llm_config.base.adapter_checkpoint = args.adapter_checkpoint
if hasattr(args, "adapter_config"):
llm_config.base.adapter_config = args.adapter_config
if hasattr(args, "adapter_checkpoint") and args.adapter_checkpoint:
if not hasattr(args, "adapter_config") or not args.adapter_config:
raise ValueError("--adapter_checkpoint requires --adapter_config")
llm_config.base.lora_config = LoraConfig(
adapter_checkpoint=args.adapter_checkpoint,
adapter_config=args.adapter_config,
)
if hasattr(args, "tokenizer_path"):
llm_config.base.tokenizer_path = args.tokenizer_path
if hasattr(args, "metadata"):