from fastapi import FastAPI, Query
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
from typing import Union, Optional, List
import torch

app = FastAPI()

device = "cuda"
model_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, dtype=torch.bfloat16, device_map=device
)
processor_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
tokenizer = AutoProcessor.from_pretrained(processor_id)


def format_text_input(prompts: List[str], system_message: str = None):
    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
    # when truncation is enabled. The processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for prompt in cleaned_txt
    ]


def _get_mistral_3_small_prompt_embeds(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    max_sequence_length: int = 512,
    system_message: str = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""",
    hidden_states_layers: List[int] = (10, 20, 30),
):
    dtype = text_encoder.dtype if dtype is None else dtype
    device = text_encoder.device if device is None else device

    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Format input messages
    messages_batch = format_text_input(prompts=prompt, system_message=system_message)

    # Process all messages at once
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
    )

    # Move to device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Forward pass through the model
    with torch.inference_mode():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

    # Only use outputs from intermediate layers and stack them
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
    out = out.to(dtype=dtype, device=device)

    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
    return prompt_embeds
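
# Shape sketch for the stacking above (illustrative; the per-layer hidden size H
# depends on the checkpoint config): with the default 3 hidden_states_layers,
# batch size B, and padded length L == max_sequence_length,
#   out:                      (B, 3, L, H)
#   out.permute(0, 2, 1, 3):  (B, L, 3, H)
#   .reshape(B, L, 3 * H):    (B, L, 3 * H)
# i.e. the selected intermediate layers are concatenated feature-wise per token.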

def _prepare_text_ids(
    x: torch.Tensor,  # (B, L, D)
    t_coord: Optional[torch.Tensor] = None,
):
    B, L, _ = x.shape
    out_ids = []
    for i in range(B):
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)
        coords = torch.cartesian_prod(t, h, w, l)
        out_ids.append(coords)
    return torch.stack(out_ids)


def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    if prompt is None:
        prompt = ""
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_embeds = _get_mistral_3_small_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device,
            max_sequence_length=max_sequence_length,
        )

    batch_size, seq_len, _ = prompt_embeds.shape
    # Duplicate the embeddings for each requested image per prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = _prepare_text_ids(prompt_embeds)
    text_ids = text_ids.to(device)
    return prompt_embeds, text_ids


@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}


@app.get("/predict")
def predict(prompt: str = Query(...)):
    prompt_embeds, text_ids = encode_prompt(
        prompt=prompt,
        device=device,
    )
    return {
        "response": {
            "prompt_embeds": prompt_embeds.cpu().tolist(),
            "text_ids": text_ids.cpu().tolist(),
        }
    }
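
# Minimal local entry point (a sketch, not part of the original service definition).
# The app can equally be served externally, e.g. `uvicorn main:app --port 8000`,
# assuming this file is saved as `main.py`; requires the `uvicorn` package.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against a locally running instance (hypothetical host/port):
#   curl "http://localhost:8000/predict?prompt=A%20red%20fox%20in%20the%20snow"
# The embeddings come back as nested JSON lists; a client can rebuild the tensor with
#   torch.tensor(resp["response"]["prompt_embeds"], dtype=torch.bfloat16)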