| | from flask import Flask, render_template, request |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import numpy as np |
| | import requests |
| | import json |
| | from huggingface_hub import hf_hub_download |
| |
|
# Flask application object; the route defined below is registered on it.
app = Flask(__name__)
# Memoizes load_client_components() results, keyed by the EE model name, so
# the tokenizer and model layers are only loaded once per process.
_cache = {}
| |
|
| |
|
def get_sigma(hidden_size: int, seed: int):
    """Build a seeded permutation of the hidden dimension and its inverse.

    Args:
        hidden_size: Number of indices to permute (``0 .. hidden_size - 1``).
        seed: Seed for NumPy's ``default_rng``; the same seed always yields
            the same permutation.

    Returns:
        A ``(sigma, sigma_inv)`` pair of 1-D ``torch.long`` tensors, where
        indexing with ``sigma_inv`` undoes indexing with ``sigma``.
    """
    forward = np.random.default_rng(seed).permutation(hidden_size)

    # Invert by scattering: slot forward[i] of the inverse holds i.
    backward = np.empty_like(forward)
    backward[forward] = np.arange(hidden_size)

    def as_long(indices):
        return torch.tensor(indices, dtype=torch.long)

    return as_long(forward), as_long(backward)
| |
|
| |
|
def load_client_components(ee_model_name: str):
    """Load, and memoize, the model pieces that run on the client.

    Downloads ``ee_config.json`` from the ``ee_model_name`` Hub repo, reads
    the hidden size and the name of the original model from it, then loads
    that model's tokenizer plus its embedding layer, final norm and LM head.
    Results are stored in the module-level ``_cache`` so later calls with the
    same name return immediately.

    Args:
        ee_model_name: Hugging Face Hub repo id that hosts ``ee_config.json``.

    Returns:
        A tuple ``(tokenizer, embed_layer, lm_head, final_norm, hidden_size)``.

    Raises:
        KeyError: If the config lacks ``"hidden_size"`` or ``"original_model"``.
        Exception: Whatever ``hf_hub_download`` / ``from_pretrained`` raise on
            download or load failure is propagated unchanged.
    """
    cached = _cache.get(ee_model_name)
    if cached is not None:
        return cached

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale-dependent default encoding.
    with open(config_path, encoding="utf-8") as f:
        ee_config = json.load(f)

    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    # NOTE(security): trust_remote_code=True executes Python code shipped in
    # the model repo named by the downloaded config. Only point this at repos
    # you trust.
    tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)

    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    embed_layer = original_model.model.embed_tokens
    lm_head = original_model.lm_head
    final_norm = original_model.model.norm
    for module in (embed_layer, lm_head, final_norm):
        module.eval()
    # Drop the local reference to the full model; only the three modules
    # pulled out above are kept by the cache.
    del original_model

    components = (tokenizer, embed_layer, lm_head, final_norm, hidden_size)
    _cache[ee_model_name] = components
    return components
| |
|
| |
|
def generate_tokens(server_url, tokenizer, embed_layer, lm_head, final_norm,
                    sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens):
    """Greedy token-by-token generation against the remote server.

    No KV cache: the client accumulates all plain embeddings and sends the
    full, growing sequence on every step.

    Each step:
      1. Permute every embedding's hidden dimension with ``sigma_t`` and cast
         to float16.
      2. POST them to ``{server_url}/generate`` and read ``"last_hidden"``
         from the JSON reply (still in sigma-space).
      3. Take the last position and undo the permutation with ``sigma_inv_t``.
      4. Run ``final_norm`` + ``lm_head`` locally and pick the argmax token.

    Args:
        server_url: Base URL of the server, without a trailing slash.
        tokenizer: Callable tokenizer returning an object with ``input_ids``;
            its ``eos_token_id`` stops generation.
        embed_layer: Module mapping token ids to embeddings.
        lm_head: Module mapping normalized hidden states to vocabulary logits.
        final_norm: Module applied to the hidden state before ``lm_head``.
        sigma_t: Long tensor of indices used to permute the hidden dimension.
        sigma_inv_t: Long tensor of indices that inverts ``sigma_t``.
        formatted_prompt: Prompt text, already run through the chat template.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        List of generated token ids. The EOS id is included when it is
        produced, since it is appended before the stop check.

    Raises:
        RuntimeError: On a non-OK HTTP status or a JSON body with an
            ``"error"`` key.
        requests.exceptions.RequestException: Connection and timeout errors
            from ``requests`` propagate unchanged.
    """
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    input_ids = inputs.input_ids

    with torch.no_grad():
        all_plain_embeds = embed_layer(input_ids)

    generated_ids = []
    if max_new_tokens <= 0:
        # Nothing to generate; avoid opening an HTTP session at all.
        return generated_ids

    endpoint = f"{server_url}/generate"

    # One Session for the whole generation so the underlying HTTP connection
    # is reused instead of being re-established on every decoding step.
    with requests.Session() as session:
        for _ in range(max_new_tokens):
            all_encrypted = all_plain_embeds[..., sigma_t].to(torch.float16)
            seq_len = all_encrypted.shape[1]
            attention_mask = torch.ones(1, seq_len, dtype=torch.long)

            payload = {
                "inputs_embeds": all_encrypted.tolist(),
                "attention_mask": attention_mask.tolist(),
            }

            resp = session.post(endpoint, json=payload, timeout=120)
            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:400]}")

            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}")

            # Only the final position is needed to predict the next token.
            # Presumably last_hidden is (batch, seq, hidden) -- TODO confirm
            # against the server.
            last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)
            last_pos_sigma = last_hidden[:, -1:, :]
            last_pos_plain = last_pos_sigma[..., sigma_inv_t]

            with torch.no_grad():
                normed = final_norm(last_pos_plain)
                logits = lm_head(normed)

            next_token_id = logits[0, -1, :].argmax().item()
            generated_ids.append(next_token_id)

            if next_token_id == tokenizer.eos_token_id:
                break

            # Extend the plain-embedding sequence with the new token so the
            # next request carries the full context.
            next_id_tensor = torch.tensor([[next_token_id]])
            with torch.no_grad():
                next_embed = embed_layer(next_id_tensor)
            all_plain_embeds = torch.cat([all_plain_embeds, next_embed], dim=1)

    return generated_ids
| |
|
| |
|
@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the client page and, on POST, run one encrypted generation.

    GET renders the empty form. POST reads ``server_url``, ``ee_seed``,
    ``prompt`` and optional ``max_tokens`` (default 256) from the form,
    builds the permutation for the seed, formats the prompt with the chat
    template, generates tokens through the remote server and decodes them.

    Failures while loading the model components, parsing the numeric fields
    or generating are rendered as an error message in the page instead of
    surfacing as an unhandled exception. A missing required form field still
    raises from ``request.form[...]`` as before.

    Returns:
        The rendered ``client.html`` template with ``result``, ``error`` and
        the submitted ``form`` values.
    """
    result = None
    error = None
    form_data = {}
    ee_model_name = 'broadfield-dev/Qwen3-0.6B-dp-ee'

    # Components are requested on every request (GET included), as before;
    # they are cached after the first successful load. A load failure is
    # reported in the page rather than crashing the request.
    components = None
    try:
        components = load_client_components(ee_model_name)
    except Exception as e:
        error = f"{type(e).__name__}: {e}"

    if request.method == "POST":
        form_data = request.form.to_dict()
        server_url = request.form["server_url"].rstrip("/")
        raw_seed = request.form["ee_seed"]
        prompt = request.form["prompt"].strip()
        raw_max_tokens = request.form.get("max_tokens", 256)

        if components is not None:
            tokenizer, embed_layer, lm_head, final_norm, hidden_size = components
            try:
                # Parsed inside the try so non-numeric input is shown as an
                # error message instead of an unhandled ValueError.
                ee_seed = int(raw_seed)
                max_tokens = int(raw_max_tokens)

                sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)

                messages = [{"role": "user", "content": prompt}]
                formatted = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                )

                gen_ids = generate_tokens(
                    server_url, tokenizer, embed_layer, lm_head, final_norm,
                    sigma_t, sigma_inv_t, formatted, max_tokens
                )

                result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

            except RuntimeError as e:
                error = str(e)
            except requests.exceptions.ConnectionError:
                error = f"Could not connect to {server_url} — is the server Space running?"
            except Exception as e:
                error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)
| |
|
| |
|
# When executed as a script, start Flask's built-in server listening on all
# network interfaces at port 7860.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)