""" Manual LoRA merging script that handles key naming issues """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer from safetensors.torch import load_file, save_file import os import argparse from tqdm import tqdm def merge_lora_weights( base_model_name, adapter_path, output_path, device_map="auto" ): """Manually merge LoRA weights into base model""" print(f"Loading base model: {base_model_name}") model = AutoModelForCausalLM.from_pretrained( base_model_name, torch_dtype=torch.bfloat16, device_map=device_map, trust_remote_code=True, low_cpu_mem_usage=True, ) print(f"Loading LoRA adapters from: {adapter_path}") adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors")) print(f"Loaded {len(adapter_weights)} adapter weights") # Load adapter config to get scaling factor import json with open(os.path.join(adapter_path, "adapter_config.json")) as f: adapter_config = json.load(f) lora_alpha = adapter_config["lora_alpha"] r = adapter_config["r"] scaling = lora_alpha / r print(f"LoRA scaling factor: {scaling} (alpha={lora_alpha}, r={r})") # Group LoRA weights by layer lora_pairs = {} for key in adapter_weights.keys(): if "lora_A" in key: base_key = key.replace(".lora_A.weight", "") lora_pairs[base_key] = { "A": adapter_weights[key], "B": adapter_weights.get(base_key + ".lora_B.weight") } print(f"Found {len(lora_pairs)} LoRA pairs to merge") # Get model state dict model_state_dict = model.state_dict() # Map adapter keys to model keys # Adapter keys: base_model.model.model.layers.X.self_attn.q_proj # Model keys might be: model.layers.X.self_attn.q_proj (depending on device_map) print("\nMerging LoRA weights...") merged_count = 0 for adapter_key, lora_weights in tqdm(lora_pairs.items()): # Remove 'base_model.model.' prefix from adapter key # adapter_key looks like: base_model.model.model.layers.0.self_attn.q_proj if adapter_key.startswith("base_model.model."): model_key = adapter_key[len("base_model.model."):] else: model_key = adapter_key # Try to find the matching key in model found = False for mk in model_state_dict.keys(): if model_key in mk or mk.endswith(model_key): model_key = mk found = True break if not found: # Try alternative key formats alternatives = [ model_key, "model." 
+ model_key, model_key.replace("model.", ""), ] for alt_key in alternatives: if alt_key in model_state_dict: model_key = alt_key found = True break if found and model_key in model_state_dict: # Merge: W' = W + (B @ A) * scaling lora_A = lora_weights["A"] lora_B = lora_weights["B"] # Move to same device as model weight device = model_state_dict[model_key].device lora_A = lora_A.to(device) lora_B = lora_B.to(device) # Compute delta_W = (lora_B @ lora_A) * scaling delta_W = (lora_B @ lora_A) * scaling # Add to original weight model_state_dict[model_key] = model_state_dict[model_key] + delta_W.to(model_state_dict[model_key].dtype) merged_count += 1 else: print(f"Warning: Could not find model key for {adapter_key}") print(f"\nSuccessfully merged {merged_count}/{len(lora_pairs)} LoRA weights") # Load merged weights back into model model.load_state_dict(model_state_dict, strict=False) # Save merged model print(f"\nSaving merged model to: {output_path}") os.makedirs(output_path, exist_ok=True) model.save_pretrained(output_path, safe_serialization=True, max_shard_size="5GB") # Also save tokenizer print("Saving tokenizer...") tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True) tokenizer.save_pretrained(output_path) print("\n✅ Merge complete!") return model if __name__ == "__main__": # For use in the Space BASE_MODEL = "moonshotai/Kimi-Linear-48B-A3B-Instruct" ADAPTER_PATH = "/app/lora_adapters" # We'll download here OUTPUT_PATH = "/app/merged_model" merge_lora_weights(BASE_MODEL, ADAPTER_PATH, OUTPUT_PATH)
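

# ---------------------------------------------------------------------------
# Optional sanity check: a minimal sketch for verifying the merged checkpoint.
# It is not invoked by the merge flow above; it simply reloads the saved model
# and tokenizer and runs a short generation. The prompt and generation settings
# are illustrative assumptions, and it presumes the merged model fits on the
# available hardware.
# ---------------------------------------------------------------------------
def sanity_check(merged_path, prompt="Hello, how are you?"):
    """Reload the merged model and generate a few tokens to confirm it loads."""
    model = AutoModelForCausalLM.from_pretrained(
        merged_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(merged_path, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))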