# broadfield-dev's picture
# Update app.py
# f171f6e verified
from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download
# The Flask application that serves the client UI.
app = Flask(__name__)

# Process-wide memo for load_client_components(), keyed by the EE repo id.
_cache: dict = {}
def get_sigma(hidden_size: int, seed: int):
    """Return a seeded permutation of the hidden dimension and its inverse.

    Both tensors are ``torch.long`` index vectors of length ``hidden_size``
    such that ``sigma[sigma_inv]`` and ``sigma_inv[sigma]`` both equal
    ``arange(hidden_size)``. The same ``seed`` always yields the same pair.
    """
    perm = np.random.default_rng(seed).permutation(hidden_size)
    # Scatter positions into place: inverse[perm[i]] = i (same as argsort(perm)).
    inverse = np.empty_like(perm)
    inverse[perm] = np.arange(hidden_size)
    return (
        torch.tensor(perm, dtype=torch.long),
        torch.tensor(inverse, dtype=torch.long),
    )
def load_client_components(ee_model_name: str):
    """Load and memoize the client-side pieces of the model behind an EE repo.

    Downloads ``ee_config.json`` from the ``ee_model_name`` Hub repo, reads
    ``hidden_size`` and ``original_model`` from it, and loads that original
    model on CPU in float32. Only the parts the client runs locally are kept:
    the tokenizer, the input embedding (``model.embed_tokens``), the final
    norm (``model.norm``) and the LM head (``lm_head``).

    Returns:
        ``(tokenizer, embed_layer, lm_head, final_norm, hidden_size)``; the
        tuple is cached in ``_cache`` under ``ee_model_name`` so later calls
        return it without reloading.

    Raises:
        KeyError: if ``ee_config.json`` lacks ``hidden_size`` or
            ``original_model``.
        ValueError: if ``hidden_size`` differs from the embedding width of the
            original model (the permutation would not fit the embeddings).
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    # JSON is UTF-8; don't depend on the platform's default encoding.
    with open(config_path, encoding="utf-8") as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    embed_layer = original_model.model.embed_tokens
    lm_head = original_model.lm_head
    final_norm = original_model.model.norm
    embed_layer.eval()
    lm_head.eval()
    final_norm.eval()
    # Drop the full model; only the three submodules above are kept alive.
    del original_model

    # get_sigma() builds the permutation from hidden_size, so it must match the
    # embedding width or every later indexing step fails with an opaque error.
    embed_width = embed_layer.weight.shape[-1]
    if embed_width != hidden_size:
        raise ValueError(
            f"ee_config.json hidden_size={hidden_size} does not match the "
            f"embedding width {embed_width} of {original_model_name}"
        )

    components = (tokenizer, embed_layer, lm_head, final_norm, hidden_size)
    _cache[ee_model_name] = components
    return components
def generate_tokens(server_url, tokenizer, embed_layer, lm_head, final_norm,
                    sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens,
                    *, stop_token_ids=None, timeout=120):
    """Greedily generate tokens by round-tripping permuted embeddings to a server.

    There is no KV cache: the client keeps the permuted ("sigma-space")
    embeddings of every token so far and posts the full, growing sequence on
    each step.

    Each step:
      1. POST all permuted embeddings to ``{server_url}/generate``.
      2. Take the last position of the returned ``last_hidden`` and undo the
         permutation with ``sigma_inv_t``.
      3. Run ``final_norm`` + ``lm_head`` locally; the argmax is the next token.
      4. Embed and permute that token, then append it to the sequence.

    Args:
        server_url: base URL of the server, without a trailing slash.
        tokenizer, embed_layer, lm_head, final_norm: client-side components.
        sigma_t, sigma_inv_t: permutation of the hidden dimension and its inverse.
        formatted_prompt: prompt text, already passed through the chat template.
        max_new_tokens: upper bound on the number of generated tokens.
        stop_token_ids: token ids that end generation; defaults to
            ``{tokenizer.eos_token_id}`` (no stop token if that is None).
        timeout: per-request timeout in seconds.

    Returns:
        list[int]: generated token ids, including the stop token if one was hit.

    Raises:
        RuntimeError: on a non-2xx response, an ``error`` field in the body, or
            a body without ``last_hidden``.
        requests.exceptions.RequestException: on connection errors / timeouts.
    """
    if stop_token_ids is None:
        stop_token_ids = {tokenizer.eos_token_id}
    stop_ids = {tok for tok in stop_token_ids if tok is not None}

    def _encrypt(plain):
        # Permute the hidden dim, then cast to float16 before serialization
        # (presumably to shrink the JSON payload).
        return plain[..., sigma_t].to(torch.float16)

    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids  # (1, seq_len)
    with torch.no_grad():
        # Permuting the last dim commutes with concatenation along dim 1, so
        # encrypting incrementally equals re-encrypting the full sequence.
        encrypted = _encrypt(embed_layer(input_ids))  # (1, seq, hidden)

    url = f"{server_url}/generate"
    generated_ids = []
    # One session for the whole generation so the connection is reused.
    with requests.Session() as session:
        for _ in range(max_new_tokens):
            payload = {
                "inputs_embeds": encrypted.tolist(),
                "attention_mask": [[1] * encrypted.shape[1]],
            }
            resp = session.post(url, json=payload, timeout=timeout)
            if not resp.ok:
                raise RuntimeError(f"Server {resp.status_code}: {resp.text[:400]}")
            body = resp.json()
            if "error" in body:
                raise RuntimeError(f"Server error: {body['error']}")
            if "last_hidden" not in body:
                raise RuntimeError("Server response is missing 'last_hidden'")

            # NOTE(review): assumes last_hidden is (1, seq, hidden) in sigma-space,
            # as the slicing below requires — confirm against the server.
            last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)
            last_plain = last_hidden[:, -1:, :][..., sigma_inv_t]  # (1, 1, hidden)

            with torch.no_grad():
                logits = lm_head(final_norm(last_plain))  # (1, 1, vocab)
            next_token_id = logits[0, -1, :].argmax().item()
            generated_ids.append(next_token_id)
            if next_token_id in stop_ids:
                break

            with torch.no_grad():
                next_embed = embed_layer(torch.tensor([[next_token_id]]))  # (1, 1, hidden)
                encrypted = torch.cat([encrypted, _encrypt(next_embed)], dim=1)
    return generated_ids
def _read_int_field(form, name, default=None):
    """Parse integer form field ``name``; a missing or blank value uses ``default``.

    Raises:
        ValueError: if the value is blank and there is no default, or it is
            not an integer.
    """
    raw = (form.get(name) or "").strip()
    if not raw:
        if default is None:
            raise ValueError(f"'{name}' is required")
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"'{name}' must be an integer, got {raw!r}") from None


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the client page and, on POST, run one encrypted generation.

    GET renders the form. POST reads ``server_url``, ``ee_seed``, ``prompt``
    and optional ``max_tokens`` from the form, derives the permutation from the
    seed, formats the prompt with the chat template (thinking disabled), and
    generates through the remote server.

    Failures in model loading, malformed form input and server/network errors
    are all shown through the template's ``error`` value instead of surfacing
    as HTTP 400/500 responses.
    """
    result = None
    error = None
    form_data = request.form.to_dict() if request.method == "POST" else {}
    ee_model_name = "broadfield-dev/Qwen3-0.6B-dp-ee"

    # Loaded on GET too (cached after the first success), so a failure here
    # is reported on the page rather than crashing the request.
    try:
        tokenizer, embed_layer, lm_head, final_norm, hidden_size = (
            load_client_components(ee_model_name)
        )
    except Exception as e:
        error = (f"Could not load client components for {ee_model_name}: "
                 f"{type(e).__name__}: {e}")
        return render_template("client.html", result=result, error=error, form=form_data)

    if request.method != "POST":
        return render_template("client.html", result=result, error=error, form=form_data)

    # Bound before the try so the ConnectionError message can always use it.
    server_url = (request.form.get("server_url") or "").strip().rstrip("/")
    try:
        if not server_url:
            raise ValueError("'server_url' is required")
        ee_seed = _read_int_field(request.form, "ee_seed")
        max_tokens = _read_int_field(request.form, "max_tokens", 256)
        prompt = (request.form.get("prompt") or "").strip()

        sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)
        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # disable Qwen3 thinking mode
        )
        gen_ids = generate_tokens(
            server_url, tokenizer, embed_layer, lm_head, final_norm,
            sigma_t, sigma_inv_t, formatted, max_tokens,
        )
        result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    except RuntimeError as e:
        error = str(e)
    except requests.exceptions.ConnectionError:
        error = f"Could not connect to {server_url} — is the server Space running?"
    except Exception as e:
        error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)
if __name__ == "__main__":
    # Development entry point: listen on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")