MiniLM: BitNet 1.58b Sparse 2:4 Instruct (5 MB)

MiniLM is an ultra-compressed 1.58-bit ternary sparse language model trained via knowledge distillation from HuggingFaceTB/SmolLM-135M-Instruct. It implements the BitNet (1.58b) architecture with Sparse 2:4 structured pruning: in each linear layer, at least 2 weights in every block of 4 (50%) are forced to zero, and the model is then healed with fine-tuning on the full Alpaca instruction dataset.

The result is a ~5 MB effective model (at true 1.58-bit packing) that runs entirely on-device: no cloud, no API, no GPU required.


🔥 Highlights

  • 25.7M parameters: about 5× smaller than the 135M teacher, yet instruction-aware
  • Sparse 2:4 structure: 24.5% of all parameters are exactly zero; within the linear layers, every group of 4 weights has at least 2 zeros
  • 1.58-bit quantisation: internal linear layers use ternary weights {-1, 0, +1}
  • Knowledge distillation: trained with KL divergence against SmolLM-135M-Instruct soft targets
  • Instruct fine-tuned: trained on the full Alpaca instruction dataset (52K examples) in ChatML format
  • 15,000 training steps on Apple MPS (Metal Performance Shaders)
  • Best validation CE loss: 2.5907, against a teacher baseline of 1.85

📐 Architecture

| Property | Value |
|---|---|
| Architecture | BitNet 1.58b (ternary linear layers) |
| Layers | 12 transformer blocks |
| Embedding dim | 256 |
| Attention heads | 4 |
| FFN hidden dim | 1024 (SwiGLU) |
| Position embeddings | Learned, 2048 positions |
| Norm | LayerNorm (post-attention) |
| Weight tying | Yes (embedding ↔ output head) |
| Sparsity | 24.5% zero weights (Sparse 2:4 structure) |
| Parameters | 25,696,768 |
| Theoretical 1.58-bit size | ~5.08 MB |
| File size on disk (fp32) | 98 MB |
| Tokenizer | HuggingFaceTB/SmolLM-135M-Instruct (49,152 vocab) |
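
As a sanity check, the parameter count and size figures in the table can be reproduced with a few lines of arithmetic. The per-component breakdown below is an assumption and is not taken from model.py: it supposes four bias-free attention projections and three bias-free SwiGLU matrices per block, and treats the remaining 6,656 parameters as 26 norm vectors of width 256.

```python
# Back-of-the-envelope parameter count for the configuration in the table.
# The per-component breakdown is an assumption; model.py is the reference.
VOCAB, DIM, LAYERS, FFN, POSITIONS = 49_152, 256, 12, 1024, 2048

embedding = VOCAB * DIM        # shared with the output head (weight tying)
positions = POSITIONS * DIM    # learned position embeddings
attention = 4 * DIM * DIM      # assumed: Q, K, V, O projections, no bias
ffn = 3 * DIM * FFN            # assumed: SwiGLU gate, up and down, no bias
linear = LAYERS * (attention + ffn)
norms = 26 * DIM               # assumed: the remaining norm parameters

total = embedding + positions + linear + norms


def size_bytes(num_params, bits_per_weight):
    """Storage needed for num_params weights at a given bit-width."""
    return num_params * bits_per_weight / 8


print(f"total parameters : {total:,}")
print(f"1.58-bit size    : {size_bytes(total, 1.58) / 1e6:.2f} MB")
print(f"fp32 size        : {size_bytes(total, 32) / 2**20:.0f} MiB")
print(f"zeros if half of the linear weights are zero: {0.5 * linear / total:.1%}")
```

Under these assumptions the total comes out to 25,696,768, and zeroing half of the linear-layer weights gives about 24.5% zeros overall, which matches the table. The 5.08 MB figure corresponds to units of 10^6 bytes, while the 98 MB file size corresponds to units of 2^20 bytes.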

BitLinear Quantisation

Every nn.Linear layer is replaced with a custom BitLinear that:

  1. Quantises weights to ternary {-1, 0, +1} via round(W / mean|W|).clamp(-1, 1)
  2. Quantises activations to 8-bit integers per token
  3. Dequantises the output using stored float scales

This happens transparently at inference: the stored weights are float32, but the effective compute is ternary × int8.
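
The three steps above can be sketched in a few lines of PyTorch. This is a minimal, inference-only illustration and not the BitLinear class from model.py: the function names are hypothetical, the symmetric absmax scaling for activations is an assumption (the text only says "8-bit integers per token"), and the integer values are held in float tensors for simplicity.

```python
import torch


def ternary_quantise(w, eps=1e-5):
    """Step 1: round(W / mean|W|).clamp(-1, 1), plus the scale needed later."""
    scale = w.abs().mean().clamp(min=eps)
    w_q = (w / scale).round().clamp(-1, 1)   # values in {-1, 0, +1}
    return w_q, scale


def int8_quantise_per_token(x, eps=1e-5):
    """Step 2: one scale per token (last dim). Absmax scaling is assumed."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    x_q = (x / scale).round().clamp(-128, 127)
    return x_q, scale


def bitlinear_forward(x, w):
    """Step 3: ternary x int8 matmul, then dequantise with the float scales."""
    w_q, w_scale = ternary_quantise(w)
    x_q, x_scale = int8_quantise_per_token(x)
    y = x_q @ w_q.t()
    return y * x_scale * w_scale


x = torch.randn(3, 256)      # 3 tokens, embedding dim 256
w = torch.randn(1024, 256)   # float32 weights, as stored in the checkpoint
print(bitlinear_forward(x, w).shape)
```

Training such a layer additionally needs a way to pass gradients through the rounding step, which this sketch leaves out.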


πŸ‹οΈ Training Details

Property Value
Teacher model HuggingFaceTB/SmolLM-135M-Instruct (135M params)
Training dataset tatsu-lab/alpaca (52K instruction pairs)
Training format ChatML (<|im_start|>user … <|im_end|>)
Sequence length 128 tokens (boundary-padded)
Batch size 8
Steps 15,000
Optimizer AdamW (lr=1e-3, weight_decay=0.01)
KD temperature T=2
KD alpha Ξ±=0.5 (equal CE + KL)
Sparse masking Backward hooks freeze zero-weight gradients
Hardware Apple M-series MPS (on-device)

Training Objective

Loss = 0.5 × CrossEntropy(student, targets)
     + 0.5 × KL(student_soft / T, teacher_soft / T) × T²
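
One way to write this objective in PyTorch is sketched below. It is an illustration under assumptions, not the project's training code: the function name kd_loss is hypothetical, logits are assumed to be flattened to shape (N, vocab), and the KL term uses the common distillation convention in which the teacher's softened distribution is the target.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, targets, T=2.0, alpha=0.5):
    """alpha * CE + (1 - alpha) * KL * T^2 for logits of shape (N, vocab)."""
    ce = F.cross_entropy(student_logits, targets)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
    return alpha * ce + (1.0 - alpha) * kl * T * T


student = torch.randn(8, 49152)
teacher = torch.randn(8, 49152)
targets = torch.randint(0, 49152, (8,))
print(kd_loss(student, teacher, targets).item())
```

The T² factor keeps the gradient magnitude of the softened KL term comparable to that of the cross-entropy term when T is changed.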

Sparse 2:4 masks are applied via backward hooks: any weight that is exactly zero has its gradient zeroed at every update step, preserving the sparsity pattern permanently.
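
The hook mechanism can be sketched as follows. Both helper names are hypothetical, and two details are assumptions: the zeros are chosen here by smallest magnitude within each group of 4 (the text does not say how they are selected), and the mask is captured once at registration rather than recomputed at every step.

```python
import torch
import torch.nn as nn


def prune_2_of_4(weight):
    """Zero the 2 smallest-magnitude entries in every group of 4 weights."""
    groups = weight.detach().reshape(-1, 4)
    drop = groups.abs().topk(2, dim=1, largest=False).indices
    mask = torch.ones_like(groups).scatter_(1, drop, 0.0)
    return (groups * mask).reshape(weight.shape)


def freeze_zero_weights(param):
    """Register a backward hook that zeroes gradients of zero-valued weights."""
    keep = (param.detach() != 0).to(param.dtype)   # mask fixed at registration
    return param.register_hook(lambda grad: grad * keep)


layer = nn.Linear(8, 4, bias=False)
with torch.no_grad():
    layer.weight.copy_(prune_2_of_4(layer.weight))
freeze_zero_weights(layer.weight)

optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-3, weight_decay=0.01)
zero_before = layer.weight.detach() == 0
for _ in range(3):
    loss = layer(torch.randn(2, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Pruned positions are still exactly zero after the optimizer steps.
print(bool((layer.weight.detach()[zero_before] == 0).all()))
```

Because the masked gradients are exactly zero, AdamW's moment estimates for those positions stay at zero, and weight decay only rescales a value that is already zero, so the pattern survives every update.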


📊 Evaluation Results

| Model | Val CE Loss | Val PPL | Final Loss | Size |
|---|---|---|---|---|
| Teacher (SmolLM-135M-Instruct) | 1.8500 | 6.36 | n/a | 270 MB |
| Dense Student (KD α=0.5) | 2.1210 | 8.34 | 2.8336 | 97 MB (5.02 MB @ 1.58b) |
| Sparse 2:4 Student (this model) | 2.5907 | n/a | 3.0393 | 98 MB (5.08 MB @ 1.58b) |

The sparse model achieves a val CE loss only 0.74 nats above the teacher while compressing it by ~27× in effective bit-width and having 24.5% of its weights permanently zeroed.
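
The reported perplexities for the teacher and the dense student are consistent with PPL = exp(CE), with CE measured in nats. The short sketch below applies the same relation to all three validation losses; the value it prints for the sparse model is derived from its reported CE loss, not a separately measured result.

```python
import math


def perplexity(ce_loss_nats):
    """Perplexity from a mean cross-entropy loss measured in nats."""
    return math.exp(ce_loss_nats)


val_ce = {
    "teacher": 1.8500,
    "dense student": 2.1210,
    "sparse 2:4 student": 2.5907,
}

for name, ce in val_ce.items():
    gap = ce - val_ce["teacher"]
    print(f"{name:20s} CE={ce:.4f}  PPL={perplexity(ce):.2f}  gap={gap:+.4f} nats")
```

For the sparse student this gives a perplexity of roughly 13.3 and a gap of about 0.74 nats to the teacher.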


💬 Qualitative Outputs (Sample)

Prompt: "What are three tips for staying healthy?"

  1. Reduce your energy to avoid chronic diseases by using a large bowl of water for an appropriate amount of time.
  2. Plant a plant-based diet into small fruits, vegetables, and other organic fats.
  3. Replace them with healthy fats and sugars as well as in the right amount of vitamins and minerals.

Prompt: "Explain what a transformer neural network is in simple terms."

A Neural network is a type of artificial neural networks that uses labeled data to detect objects or objects in order to create output. It consists of the nodes and neurons, such as each other from one branch and outputs, but it can be used to identify natural language features or patterns that are not in the input.

Prompt: "Explain how photosynthesis works."

Photosynthesis is a form of glucose that are absorbed by the process by plants, including oxygen, and other chemical energy. The sun is then used to produce light energy from the environment, releasing light into our cells, and helping to absorb carbon dioxide. During photosynthesis, water is a renewable source of energy with oxygen, where it takes about 30% of oxygen.

⚠️ This is a 25M-parameter research model. Factual accuracy is limited: it follows the instruction format well but may hallucinate content. Do not use it for factual lookup, translation, or production applications.


🚀 Usage

Because this model uses a custom ternary architecture, it cannot be loaded via AutoModel. You must use the BitGPT class from model.py (included in this repo).

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model import BitGPT

# 1. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")

# 2. Initialise model
model = BitGPT(
    vocab_size=len(tokenizer),  # 49152
    embed_dim=256,
    num_layers=12,
    num_heads=4,
    tie_weights=True,
)

# 3. Load weights
model.load_state_dict(
    torch.load("bitnet_sparse_instruct_15k.pt", map_location="cpu", weights_only=True)
)
model.eval()

# 4. Generate a response
def generate(prompt, max_tokens=150, temperature=0.7, top_k=40):
    chatml = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(chatml, add_special_tokens=False)
    x = torch.tensor([ids])
    generated = []

    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(x)[:, -1, :].float()
            # top-k sampling
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits / temperature, dim=-1)
            nid = torch.multinomial(probs, 1).item()
            generated.append(nid)
            if "<|im_end|>" in tokenizer.decode([nid]):
                break
            x = torch.cat([x, torch.tensor([[nid]])], dim=1)
            if x.size(1) > 128:
                x = x[:, -128:]

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

print(generate("What are three tips for staying healthy?"))

📁 Files in This Repository

| File | Description |
|---|---|
| bitnet_sparse_instruct_15k.pt | Model weights (float32, 98 MB on disk) |
| model.py | BitGPT + BitLinear + RMSNorm architecture source |
| README.md | This file |
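
The checkpoint above stores float32 values, so the ~5 MB figure is only reached once ternary weights are packed. The repository does not describe a packing format; the sketch below shows one possible scheme, given purely as an assumption, which stores 5 ternary values per byte in base 3 (3^5 = 243 fits in one byte). That works out to 1.6 bits per weight, close to the theoretical log2(3) ≈ 1.585. The function names are hypothetical.

```python
def pack_ternary(values):
    """Pack ternary values (-1, 0, +1) into bytes, 5 values per byte."""
    out = bytearray()
    for i in range(0, len(values), 5):
        byte = 0
        # Build a base-3 number; the first value is the least significant digit.
        for v in reversed(values[i:i + 5]):
            byte = byte * 3 + (v + 1)   # map {-1, 0, +1} to {0, 1, 2}
        out.append(byte)
    return bytes(out)


def unpack_ternary(data, count):
    """Inverse of pack_ternary; count is the original number of values."""
    values = []
    for byte in data:
        for _ in range(5):
            values.append(byte % 3 - 1)
            byte //= 3
    return values[:count]


weights = [-1, 0, 1, 1, 0, -1, -1]
packed = pack_ternary(weights)
print(len(packed), unpack_ternary(packed, len(weights)) == weights)
```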

🔬 Research Context

This model is part of an ongoing research project exploring the viability of 1.58-bit language models running entirely on edge devices (CPU/Apple Silicon). The project investigates:

  • Knowledge distillation at extreme compression ratios (135M → 25M params)
  • Combining BitNet quantisation with Sparse 2:4 structured pruning
  • On-device instruction following without cloud inference

The teacher model (SmolLM-135M-Instruct) achieves PPL 6.36; this model reaches a validation CE loss of 2.5907 (about 13.3 in perplexity, since PPL = exp(CE)) with only 5 MB of effective weight storage: a **27× compression** with about 0.74 nats of CE loss degradation.


📜 License

MIT: free to use, modify, and distribute.

🙏 Citation / Attribution

If you use this model, please credit:
