Instructions for using minishlab/potion-code-16M with libraries, inference providers, notebooks, and local apps.
- Libraries
  - Model2Vec

How to use minishlab/potion-code-16M with Model2Vec (a short encoding example follows the notebook links below):

```python
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-code-16M")
```

- Notebooks
  - Google Colab
  - Kaggle
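As a quick illustration of the snippet above, here is a minimal encode-and-compare sketch; the example texts and the cosine-similarity computation are illustrative additions, not part of the model card:

```python
import numpy as np
from model2vec import StaticModel

# Load the published static embedding model; inference is a token-embedding lookup with mean pooling.
model = StaticModel.from_pretrained("minishlab/potion-code-16M")

# A natural-language query followed by two code snippets (illustrative examples).
texts = [
    "how to reverse a list in python",
    "def reverse_list(items):\n    return items[::-1]",
    "public static int add(int a, int b) { return a + b; }",
]
embeddings = model.encode(texts)  # numpy array, one 256-dim row per input

# Cosine similarity of the query against each snippet; the Python snippet should rank first.
normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
print(normed[0] @ normed[1:].T)
```

The same checkpoint can also be loaded into Sentence Transformers via `StaticEmbedding.from_model2vec`, which is how the contrastive step of the reproduction script below attaches the static embeddings to a `SentenceTransformer`. The full reproduction script follows.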
| """Reproduction script for potion-code-16M. | |
| Runs the full pipeline: distill → tokenlearn → contrastive fine-tuning. | |
| Requirements: | |
| pip install model2vec tokenlearn sentence-transformers datasets skeletoken einops | |
| The three model checkpoints are saved to: | |
| ./models/potion-code-16M-distilled | |
| ./models/potion-code-16M-tokenlearn | |
| ./models/potion-code-16M-contrastive ← final model | |
| """ | |
from __future__ import annotations

import logging
import random
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from huggingface_hub import snapshot_download
from model2vec import StaticModel
from model2vec.distill import distill_from_model
from model2vec.distill.inference import post_process_embeddings
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.models import StaticEmbedding
from sentence_transformers.training_args import BatchSamplers
from skeletoken import TokenizerModel
from sklearn.decomposition import PCA
from tokenlearn.losses import Loss
from tokenlearn.model import StaticModelForFineTuning
from tokenlearn.utils import create_vocab
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
TEACHER_MODEL = "nomic-ai/CodeRankEmbed"
OUTPUT_DIR = Path("models")

# Distill
VOCAB_SIZE = 42_000  # extra tokens mined from CornStack → ~62.5k total → ~16M params
PCA_DIMS = 256
SIF_COEFFICIENT = 1e-4

# Tokenlearn
TOKENLEARN_DOCS_DATASET = "minishlab/tokenlearn-cornstack-docs-coderankembed"
TOKENLEARN_QUERIES_DATASET = "minishlab/tokenlearn-cornstack-queries-coderankembed"
TOKENLEARN_LANGUAGES = ["go", "java", "javascript", "php", "python", "ruby"]
TOKENLEARN_MAX_PER_LANGUAGE = 20_000  # (20k docs + 20k queries) × 6 langs = 240k total
TOKENLEARN_LR = 1e-3
TOKENLEARN_MAX_EPOCHS = 20  # early stopping (patience=5) typically kicks in earlier
TOKENLEARN_BATCH_SIZE = 128

# Contrastive
CORNSTACK_DATASETS = {
    "python": "nomic-ai/cornstack-python-v1",
    "java": "nomic-ai/cornstack-java-v1",
    "php": "nomic-ai/cornstack-php-v1",
    "go": "nomic-ai/cornstack-go-v1",
    "javascript": "nomic-ai/cornstack-javascript-v1",
    "ruby": "nomic-ai/cornstack-ruby-v1",
}
CONTRASTIVE_MAX_PER_LANGUAGE = 20_000  # 20k × 6 langs = 120k pairs total
CONTRASTIVE_LR = 5e-3
CONTRASTIVE_EPOCHS = 3
CONTRASTIVE_BATCH_SIZE = 512
CONTRASTIVE_SEED = 42
def apply_post_sif(model: StaticModel, pca_dims: int, sif_coefficient: float) -> StaticModel:
    """Apply post-SIF re-regularization to a static model."""
    embeddings_np = model.embedding.astype(np.float32)
    processed, weights = post_process_embeddings(embeddings_np, pca_dims=pca_dims, sif_coefficient=sif_coefficient)
    logger.info("post_process_embeddings: %s → %s", embeddings_np.shape, processed.shape)
    model.embedding = processed
    model.weights = weights
    return model
def run_distill(save_path: Path) -> None:
    """Distill CodeRankEmbed into a static model with an extended code vocabulary."""
    logger.info("Downloading %s ...", TEACHER_MODEL)
    local_path = snapshot_download(TEACHER_MODEL)
    model = AutoModel.from_pretrained(local_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True, use_fast=True)

    # Load tokenlearn corpus texts for vocab mining (docs + queries, 20k/lang)
    logger.info("Loading texts for vocabulary mining ...")
    shards = []
    for lang in TOKENLEARN_LANGUAGES:
        docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        shards.extend([docs, queries])
    corpus = concatenate_datasets(shards)
    texts: list[str] = list(corpus["text"])
    logger.info("Loaded %d texts for vocab mining.", len(texts))

    logger.info("Mining vocabulary (target size=%d) ...", VOCAB_SIZE)
    vocab = create_vocab(texts=texts, vocab_size=VOCAB_SIZE)
    logger.info("Mined %d tokens.", len(vocab))

    # Filter: keep only new single-token entries not already in the CodeRankEmbed vocabulary.
    tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer).prune_added_tokens()
    preprocessor = tokenizer_model.preprocessor
    seen = set(tokenizer_model.sorted_vocabulary)
    filtered = []
    for token in vocab:
        preprocessed = preprocessor.preprocess(token)
        if len(preprocessed) == 1 and preprocessed[0] not in seen:
            seen.add(preprocessed[0])
            filtered.append(preprocessed[0])
    logger.info("Vocabulary after filtering: %d tokens added to CodeRankEmbed.", len(filtered))

    # NomicBERT requires monkey-patched embedding accessors.
    model.get_input_embeddings = lambda: model.embeddings.word_embeddings
    model.set_input_embeddings = lambda v: setattr(model.embeddings, "word_embeddings", v)

    logger.info("Distilling (pca_dims=%d, sif=%g) ...", PCA_DIMS, SIF_COEFFICIENT)
    static_model = distill_from_model(
        model=model,
        tokenizer=tokenizer,
        vocabulary=filtered,
        pca_dims=PCA_DIMS,
        sif_coefficient=SIF_COEFFICIENT,
        pooling="mean",
        quantize_to="float32",
    )
    save_path.mkdir(parents=True, exist_ok=True)
    static_model.save_pretrained(str(save_path))
    logger.info(
        "Distilled model saved to %s (vocab=%d, dims=%d)",
        save_path,
        static_model.embedding.shape[0],
        static_model.embedding.shape[1],
    )
def run_tokenlearn(base_model_path: Path, save_path: Path) -> None:
    """Fine-tune the distilled model on CornStack using cosine similarity loss."""
    # Load 20k docs + 20k queries per language → 240k total
    logger.info(
        "Loading tokenlearn data (docs + queries, %d/lang × %d langs) ...",
        TOKENLEARN_MAX_PER_LANGUAGE,
        len(TOKENLEARN_LANGUAGES),
    )
    shards = []
    for lang in TOKENLEARN_LANGUAGES:
        docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        shards.extend([docs, queries])
    dataset = concatenate_datasets(shards)
    logger.info("Total samples: %d", len(dataset))

    train_txt: list[str] = list(dataset["text"])
    train_vec = np.array(dataset["embedding"], dtype=np.float32)

    # Drop rows whose teacher embeddings contain NaNs.
    non_nan_mask = ~np.isnan(train_vec).any(axis=1)
    train_txt = np.array(train_txt)[non_nan_mask].tolist()
    train_vec = train_vec[non_nan_mask]
    logger.info("Loaded %d samples, raw vector shape: %s", len(train_txt), train_vec.shape)

    logger.info("Fitting PCA to %d dims ...", PCA_DIMS)
    pca = PCA(n_components=PCA_DIMS)
    train_vec = pca.fit_transform(train_vec)
    logger.info("Explained variance: %.4f. Shape: %s", pca.explained_variance_ratio_.cumsum()[-1], train_vec.shape)

    logger.info("Loading base model from %s ...", base_model_path)
    base_model = StaticModel.from_pretrained(str(base_model_path), force_download=False)
    if base_model.embedding.dtype != np.float32:
        base_model.embedding = base_model.embedding.astype(np.float32)

    trainable = StaticModelForFineTuning.from_static_model(
        model=base_model,
        out_dim=PCA_DIMS,
        loss=Loss("cosine"),
    )
    logger.info(
        "Training tokenlearn (lr=%g, max_epochs=%d, batch=%d) ...",
        TOKENLEARN_LR,
        TOKENLEARN_MAX_EPOCHS,
        TOKENLEARN_BATCH_SIZE,
    )
    trainable.fit(
        X=train_txt,
        y=torch.from_numpy(train_vec.astype(np.float32)),
        batch_size=TOKENLEARN_BATCH_SIZE,
        learning_rate=TOKENLEARN_LR,
        max_epochs=TOKENLEARN_MAX_EPOCHS,
        early_stopping_patience=5,
        use_wandb=False,
    )
    logger.info("Tokenlearn training complete.")

    trained_model = trainable.to_static_model()
    trained_model = apply_post_sif(trained_model, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
    save_path.mkdir(parents=True, exist_ok=True)
    trained_model.save_pretrained(str(save_path))
    logger.info("Tokenlearn model saved to %s", save_path)
def run_contrastive(base_model_path: Path, save_path: Path) -> None:
    """Fine-tune the tokenlearn model using MultipleNegativesRankingLoss on CornStack pairs."""
    random.seed(CONTRASTIVE_SEED)
    logger.info(
        "Streaming CornStack pairs (%d/lang × %d langs) ...", CONTRASTIVE_MAX_PER_LANGUAGE, len(CORNSTACK_DATASETS)
    )
    all_queries: list[str] = []
    all_docs: list[str] = []
    for lang, hf_name in CORNSTACK_DATASETS.items():
        hf_ds = load_dataset(hf_name, split="train", streaming=True)
        hf_ds = hf_ds.shuffle(seed=CONTRASTIVE_SEED, buffer_size=10_000)
        kept = 0
        seen_q: set[str] = set()
        seen_d: set[str] = set()
        for row in hf_ds:
            q, d = row.get("query"), row.get("document")
            if not isinstance(q, str) or not isinstance(d, str):
                continue
            if len(q) < 32 or len(d) < 32:
                continue
            if q in seen_q or d in seen_d:
                continue
            seen_q.add(q)
            seen_d.add(d)
            all_queries.append(q)
            all_docs.append(d)
            kept += 1
            if kept >= CONTRASTIVE_MAX_PER_LANGUAGE:
                break
        logger.info(" %s: %d pairs", lang, kept)
    logger.info("Total pairs: %d", len(all_queries))

    train_dataset = Dataset.from_dict({"anchor": all_queries, "positive": all_docs})
    static_embedding = StaticEmbedding.from_model2vec(str(base_model_path))
    model = SentenceTransformer(modules=[static_embedding])
    loss = MultipleNegativesRankingLoss(model)

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(save_path) + "-checkpoints",
        num_train_epochs=CONTRASTIVE_EPOCHS,
        per_device_train_batch_size=CONTRASTIVE_BATCH_SIZE,
        learning_rate=CONTRASTIVE_LR,
        warmup_ratio=0.1,  # 10% of training steps for LR warmup (warmup_steps expects an integer count)
        fp16=False,
        bf16=False,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        save_strategy="no",
        logging_steps=100,
        logging_first_step=True,
        report_to=[],
    )
    logger.info(
        "Training contrastive (lr=%g, epochs=%d, batch=%d) ...",
        CONTRASTIVE_LR,
        CONTRASTIVE_EPOCHS,
        CONTRASTIVE_BATCH_SIZE,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()
    logger.info("Contrastive training complete.")

    # Copy the contrastively trained weights back into a Model2Vec StaticModel and re-apply post-SIF.
    base_m2v = StaticModel.from_pretrained(str(base_model_path), force_download=False)
    base_m2v.embedding = model[0].embedding.weight.detach().cpu().float().numpy()
    final_model = apply_post_sif(base_m2v, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
    save_path.mkdir(parents=True, exist_ok=True)
    final_model.save_pretrained(str(save_path))
    logger.info("Final model saved to %s", save_path)
if __name__ == "__main__":
    distilled_path = OUTPUT_DIR / "potion-code-16M-distilled"
    tokenlearn_path = OUTPUT_DIR / "potion-code-16M-tokenlearn"
    contrastive_path = OUTPUT_DIR / "potion-code-16M-contrastive"

    logger.info("=== Step 1/3: Distill ===")
    run_distill(save_path=distilled_path)

    logger.info("=== Step 2/3: Tokenlearn ===")
    run_tokenlearn(base_model_path=distilled_path, save_path=tokenlearn_path)

    logger.info("=== Step 3/3: Contrastive ===")
    run_contrastive(base_model_path=tokenlearn_path, save_path=contrastive_path)

    logger.info("Done. Final model: %s", contrastive_path)
```