API & Inference Usage

This guide shows how to load the MiniLM 1.58-bit base model and attach ("snap on") custom LoRA adapters at inference time.

Python Inference (PyTorch)

Because MiniLM uses custom ternary BitLinear layers, it cannot be loaded through the standard transformers AutoModel API. Use the provided model.py and lora.py scripts instead.
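
To make "ternary" concrete, here is a minimal, dependency-free sketch of absmean quantization, the scheme described for BitNet b1.58. It is an assumption that the BitLinear layers in model.py work this way; the function names below are hypothetical and do not come from the repository. Each weight ends up in {-1, 0, +1}, which carries log2(3) ≈ 1.58 bits of information, hence the "1.58-bit" name.

```python
import math

def ternary_quantize(weights):
    """Quantize a list of floats to {-1, 0, +1} using an absmean scale.

    Hypothetical helper for illustration; not part of model.py.
    """
    # Scale is the mean absolute weight (epsilon avoids division by zero)
    scale = sum(abs(w) for w in weights) / len(weights) + 1e-8
    # Round each scaled weight to the nearest integer, then clamp to [-1, 1]
    quantized = [max(-1, min(1, round(w / scale))) for w in weights]
    return quantized, scale

def dequantize(quantized, scale):
    """Map ternary values back to approximate float weights."""
    return [q * scale for q in quantized]

weights = [0.9, -0.05, 0.4, -1.2, 0.02, -0.6]
q, scale = ternary_quantize(weights)
print(q)                          # [1, 0, 1, -1, 0, -1]
print(round(math.log2(3), 2))     # 1.58 bits per ternary weight
```

Small weights collapse to 0 and large ones saturate at ±1, so only the sign pattern and a single scale per tensor need to be stored.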

1. Loading the Base Model

import torch
from transformers import AutoTokenizer
from model import BitGPT

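# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')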
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")

# Initialize the 12-Layer Tied Architecture
model = BitGPT(vocab_size=len(tokenizer), embed_dim=256, num_layers=12, num_heads=4, tie_weights=True).to(device)

# Load the frozen 1.58-bit Base Weights
model.load_state_dict(torch.load("minilm_base.pt", map_location=device))
model.eval()

2. Injecting a "Side-Car" LoRA

To run a specific task (such as smart-home JSON extraction), first wrap the model's linear layers with the custom BitLoraLinear adapter, then load the LoRA weights for that task.

from lora import inject_lora

# Wrap the model's layers with LoRA adapters
model = inject_lora(model, r=8, lora_alpha=16).to(device)

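# Snap on the custom 1MB weights. strict=False allows a checkpoint that holds only
# the LoRA parameters: base-weight keys absent from the file are left unchanged.
result = model.load_state_dict(torch.load("lora_smarthome.pt", map_location=device), strict=False)
# strict=False also hides misnamed keys, so confirm that every key in the file was used
assert not result.unexpected_keys, result.unexpected_keys
model.eval()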
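
For intuition about what the adapter adds, the sketch below shows the standard LoRA forward pass, y = Wx + (lora_alpha / r) · B(Ax), in plain Python. The internals of BitLoraLinear are not shown in this guide, so treat this as an assumption about its behaviour rather than a description of lora.py; `matvec`, `lora_forward`, and `lora_param_count` are hypothetical names, and `r` and `lora_alpha` are assumed to have their usual LoRA meaning.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def lora_forward(x, W, A, B, r, lora_alpha):
    """Frozen base projection plus a scaled low-rank update."""
    base = matvec(W, x)                 # frozen (ternary) base weights
    update = matvec(B, matvec(A, x))    # down-project with A, up-project with B
    scaling = lora_alpha / r
    return [b + scaling * u for b, u in zip(base, update)]

def lora_param_count(in_features, out_features, r):
    """Trainable parameters added to one wrapped linear layer."""
    return r * (in_features + out_features)

# Toy sizes: 3 inputs, 2 outputs, rank 2. lora_alpha / r = 2, the same ratio as 16 / 8.
W = [[1, 0, -1], [0, 1, 1]]
A = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0]]
x = [1.0, 2.0, 3.0]

# B is conventionally initialized to zero, so an untrained adapter leaves the output unchanged
B_zero = [[0.0, 0.0], [0.0, 0.0]]
y_zero = lora_forward(x, W, A, B_zero, r=2, lora_alpha=4)
print(y_zero)  # [-2.0, 5.0]

# A non-zero B shifts the output without touching W
B_trained = [[1.0, 0.0], [0.0, 1.0]]
y_trained = lora_forward(x, W, A, B_trained, r=2, lora_alpha=4)

# One 256x256 layer at r=8 adds 4096 adapter parameters
print(lora_param_count(256, 256, 8))  # 4096
```

Because only A and B are stored in the adapter file, swapping tasks means loading a different small checkpoint while the base weights stay in place.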

3. Generation Loop

To generate text, format your prompt with the standard ChatML tags:

prompt = "Uh, it's freezing in here, can you turn up the heat in the living room?"
chatml_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(chatml_text, return_tensors="pt").to(device)

max_new_tokens = 60
# Resolve the <|im_end|> id from the tokenizer instead of hard-coding 2
stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")}

with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = model(input_ids)
        next_token_logits = logits[:, -1, :]

        # Greedy decoding: pick the highest-scoring token (assumes batch size 1)
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop once the model emits an end-of-turn token
        if next_token.item() in stop_ids:
            break

output_text = tokenizer.decode(input_ids[0])
final_output = output_text.split("<|im_start|>assistant\n")[-1].replace("<|im_end|>", "").strip()

print(final_output)
# Output: {"device": "thermostat", "action": "increase_temp", "room": "living_room"}
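
The prompt formatting and reply clean-up above can be factored into small helpers. The following is a sketch, not part of the repository: `build_chatml_prompt`, `extract_reply`, and `parse_json_reply` are hypothetical names, and the sample decoded string is constructed by hand to stand in for real model output. A small model can emit malformed JSON, so the parser returns None instead of raising.

```python
import json

IM_START = "<|im_start|>"
IM_END = "<|im_end|>"

def build_chatml_prompt(messages):
    """Render a list of {'role', 'content'} dicts as ChatML and open the assistant turn."""
    parts = [f"{IM_START}{m['role']}\n{m['content']}{IM_END}\n" for m in messages]
    parts.append(f"{IM_START}assistant\n")
    return "".join(parts)

def extract_reply(decoded_text):
    """Return the last assistant turn, cut at the first <|im_end|> that follows it."""
    reply = decoded_text.split(f"{IM_START}assistant\n")[-1]
    return reply.split(IM_END)[0].strip()

def parse_json_reply(reply):
    """Parse the reply as JSON; return None if it is not valid JSON."""
    try:
        return json.loads(reply)
    except json.JSONDecodeError:
        return None

prompt = build_chatml_prompt(
    [{"role": "user", "content": "Turn up the heat in the living room."}]
)

# Hand-written stand-in for tokenizer.decode(input_ids[0])
decoded = prompt + '{"device": "thermostat", "action": "increase_temp", "room": "living_room"}' + IM_END

command = parse_json_reply(extract_reply(decoded))
if command is not None:
    print(command["action"])  # increase_temp
```

Unlike the `replace` call in the loop above, `extract_reply` discards anything generated after the first `<|im_end|>`, which helps if the stop condition is missed and the model keeps generating.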