diff --git a/.ci/scripts/test_lora.sh b/.ci/scripts/test_lora.sh
index 71307ca086e..17e42988c4d 100644
--- a/.ci/scripts/test_lora.sh
+++ b/.ci/scripts/test_lora.sh
@@ -41,12 +41,14 @@ HF_ADAPTER_PATH=$(
   --files "adapter_config.json" "adapter_model.safetensors"
 )
 
+# Set environment variables for OmegaConf interpolation in the YAML config.
+export LORA_ADAPTER_CHECKPOINT="${HF_ADAPTER_PATH}/adapter_model.safetensors"
+export LORA_ADAPTER_CONFIG="${HF_ADAPTER_PATH}/adapter_config.json"
+
 ### SINGLE LORA PTE ###
 # Export LoRA PTE file.
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
-  --config examples/models/qwen3/config/qwen3_xnnpack.yaml \
-  +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
-  +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
+  --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
   +export.output_name="qwen_lora_math_full.pte"
 
 # Capture the path of the downloaded qwen artifacts
@@ -93,9 +95,7 @@ fi
 ### PROGRAM DATA SEPARATION ###
 # Export LoRA PTE, LoRA PTD, foundation PTD file.
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
-  --config examples/models/qwen3/config/qwen3_xnnpack.yaml \
-  +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
-  +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
+  --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
   +export.output_name="qwen_lora_math.pte" \
   +export.foundation_weights_file="qwen_foundation.ptd" \
   +export.lora_weights_file="qwen_lora_math.ptd"
@@ -108,7 +108,7 @@ cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math.pte --dat
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
-RESULT=$(cat result.txt)
+RESULT=$(cat result2.txt)
 if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
@@ -143,8 +143,11 @@ So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
 The answer is: 12<|im_end|>"
 
 # Export Quantized PTE, PTD file, no LoRA.
+# Override base.lora_config=null to avoid creating a LoRA model
+# and loading LoRA weights.
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
-  --config examples/models/qwen3/config/qwen3_xnnpack.yaml \
+  --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
+  base.lora_config=null \
   +export.output_name="qwen_q.pte" \
   +export.foundation_weights_file="qwen_foundation_q.ptd" \
   +quantization.qmode="8da4w" \
@@ -152,9 +155,7 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
 
 # Export Quantized LoRA PTE, LoRA PTD, foundation PTD file.
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
-  --config examples/models/qwen3/config/qwen3_xnnpack.yaml \
-  +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \
-  +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \
+  --config examples/models/qwen3/config/qwen3_xnnpack_lora.yaml \
   +export.output_name="qwen_lora_math_q.pte" \
   +export.foundation_weights_file="qwen_foundation_lora_q.ptd" \
   +export.lora_weights_file="qwen_lora_math_q.ptd" \
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 1ec85936f7a..8b35d7d3155 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         checkpoint_path = self.llm_config.base.checkpoint
         params_path = self.llm_config.base.params
 
-        # Adapter checkpoint and config.
-        adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
-        adapter_config_path = self.llm_config.base.adapter_config
-        assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
-            adapter_checkpoint_path is not None and adapter_config_path is not None
-        ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
+        # LoRA adapter configuration.
+        lora_config = self.llm_config.base.lora_config
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
@@ -69,10 +65,18 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
-        # Get adapter checkpoint and config.
+        # Get adapter checkpoint.
         adapter_checkpoint = {}
-        adapter_config = {}
-        if adapter_checkpoint_path:
+        if lora_config:
+            # Resolve LoRA params from adapter_config JSON if not already set.
+            if lora_config.adapter_config and lora_config.lora_rank == 0:
+                with open(lora_config.adapter_config, "r") as f:
+                    cfg = json.load(f)
+                lora_config.lora_rank = cfg["r"]
+                lora_config.lora_alpha = cfg["lora_alpha"]
+                lora_config.target_modules = cfg["target_modules"]
+
+            adapter_checkpoint_path = lora_config.adapter_checkpoint
             if adapter_checkpoint_path.endswith(".pt"):
                 adapter_checkpoint = torch.load(
                     adapter_checkpoint_path, map_location=device, mmap=True
@@ -92,22 +96,6 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
                 raise ValueError(
                     f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
                 )
-
-            with open(adapter_config_path, "r") as f:
-                adapter_config_full = json.loads(f.read())
-                if (
-                    "r" not in adapter_config_full
-                    or "lora_alpha" not in adapter_config_full
-                    or "target_modules" not in adapter_config_full
-                ):
-                    raise ValueError(
-                        "Adapter config must contain r, lora_alpha, and target_modules."
-                    )
-            adapter_config = {
-                "r": adapter_config_full["r"],
-                "lora_alpha": adapter_config_full["lora_alpha"],
-                "target_modules": adapter_config_full["target_modules"],
-            }
 
         checkpoint.update(adapter_checkpoint)
         output_prune_map = None
@@ -133,8 +121,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
             input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
+            r=lora_config.lora_rank if lora_config else None,
+            lora_alpha=lora_config.lora_alpha if lora_config else None,
+            target_modules=lora_config.target_modules if lora_config else None,
             **params,
-            **adapter_config,
         )
 
         if model_args.use_scaled_rope:
@@ -356,9 +346,10 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
 
         embedding_bit_width, embedding_group_size = None, None
         if self.llm_config.base.preq_embedding_quantize:
-            embedding_bit_width, embedding_group_size = (
-                self.llm_config.base.preq_embedding_quantize.split(",")
-            )
+            (
+                embedding_bit_width,
+                embedding_group_size,
+            ) = self.llm_config.base.preq_embedding_quantize.split(",")
             from .source_transformation.pre_quantization import (
                 transform_embedding_for_pre_quantization,
             )
diff --git a/examples/models/qwen3/config/qwen3_xnnpack.yaml b/examples/models/qwen3/config/qwen3_xnnpack_lora.yaml
similarity index 74%
rename from examples/models/qwen3/config/qwen3_xnnpack.yaml
rename to examples/models/qwen3/config/qwen3_xnnpack_lora.yaml
index 1c4801bf5ef..3836b7793fb 100644
--- a/examples/models/qwen3/config/qwen3_xnnpack.yaml
+++ b/examples/models/qwen3/config/qwen3_xnnpack_lora.yaml
@@ -2,6 +2,9 @@ base:
   model_class: "qwen3_0_6b"
   params: "examples/models/qwen3/config/0_6b_config.json"
   metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}'
+  lora_config:
+    adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
+    adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
 
 model:
   use_kv_cache: True
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index a7453fd09c1..db56b686ba5 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -62,6 +62,36 @@ class PreqMode(str, Enum):
     preq_8da4w_out_8da8w = "8da4w_output_8da8w"
 
 
+@dataclass
+class LoraConfig:
+    """LoRA adapter configuration.
+
+    Can be created in two ways:
+
+    1. From an adapter_config JSON file:
+        LoraConfig(
+            adapter_checkpoint="/path/to/adapter.safetensors",
+            adapter_config="/path/to/adapter_config.json",
+        )
+        Note: the user is responsible for ensuring that values parsed
+        from the adapter config do not conflict with any explicitly set values.
+
+    2. With explicit values:
+        LoraConfig(
+            adapter_checkpoint="/path/to/adapter.safetensors",
+            lora_rank=16,
+            lora_alpha=32,
+            target_modules=["q_proj", "v_proj"],
+        )
+    """
+
+    adapter_checkpoint: str
+    adapter_config: Optional[str] = None
+    lora_rank: int = 0
+    lora_alpha: int = 0
+    target_modules: List[str] = field(default_factory=list)
+
+
 @dataclass
 class BaseConfig:
     """
@@ -77,11 +107,7 @@ class BaseConfig:
             If left empty, the model will either be initialized with random weights if it is a
             Llama model or the weights will be downloaded from HuggingFace if it is a non-Llama model.
-        adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
-            the model has trained LoRA adapters. Must provide
-            adapter_config.json.
-        adapter_config: Path to the adapter_config.json file from torchtune.
-            Used if the model has trained LoRA adapters. Must provide adapter.pt.
+        lora_config: LoRA adapter configuration; see LoraConfig.
         tokenizer_path: Path to the tokenizer file.
         metadata: Json string containing metadata information.
             e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
@@ -98,8 +124,7 @@ class BaseConfig:
     model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
-    adapter_checkpoint: Optional[str] = None
-    adapter_config: Optional[str] = None
+    lora_config: Optional[LoraConfig] = None
     tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
     use_lora: int = 0
@@ -536,10 +561,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
             llm_config.base.checkpoint = args.checkpoint
-        if hasattr(args, "adapter_checkpoint"):
-            llm_config.base.adapter_checkpoint = args.adapter_checkpoint
-        if hasattr(args, "adapter_config"):
-            llm_config.base.adapter_config = args.adapter_config
+        if hasattr(args, "adapter_checkpoint") and args.adapter_checkpoint:
+            if not hasattr(args, "adapter_config") or not args.adapter_config:
+                raise ValueError("--adapter_checkpoint requires --adapter_config")
+            llm_config.base.lora_config = LoraConfig(
+                adapter_checkpoint=args.adapter_checkpoint,
+                adapter_config=args.adapter_config,
+            )
         if hasattr(args, "tokenizer_path"):
             llm_config.base.tokenizer_path = args.tokenizer_path
         if hasattr(args, "metadata"):
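
For reference, a minimal sketch of the two LoraConfig construction modes described in its docstring; the import path is assumed from the file location extension/llm/export/config/llm_config.py and the paths are placeholders.

# Minimal usage sketch of the new LoraConfig (assumed import path; paths are placeholders).
from extension.llm.export.config.llm_config import LoraConfig

# Mode 1: point at adapter_config.json; model.py fills in lora_rank, lora_alpha,
# and target_modules from the JSON while lora_rank is still 0.
lora_from_json = LoraConfig(
    adapter_checkpoint="/path/to/adapter_model.safetensors",
    adapter_config="/path/to/adapter_config.json",
)

# Mode 2: pass explicit values; the JSON lookup in model.py is skipped because
# lora_rank is non-zero.
lora_explicit = LoraConfig(
    adapter_checkpoint="/path/to/adapter_model.safetensors",
    lora_rank=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
)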