|
|
| import os |
| import json |
| import torch |
| import re |
| from transformers import AutoTokenizer, TextStreamer |
| from unsloth import FastLanguageModel |
| from peft import PeftModel |
| from unsloth.chat_templates import get_chat_template |
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler that runs an LLM-based safety
    classifier over a conversation.

    A base model + LoRA adapter pair is selected via the ``SELECTED_MODEL``
    environment variable ("Gemma3-12" or "Qwen2.5", default "Qwen2.5"),
    loaded in 4-bit via Unsloth, and prompted with a fixed safety policy.
    ``__call__`` returns ``{"is_safe": bool, "safety_result": dict}``.
    """

    def __init__(self, model_dir):
        """Load the selected model/tokenizer and prepare the policy prompt.

        Args:
            model_dir: Required by the HF Inference Endpoints handler
                contract; unused here because weights are pulled from the
                Hub (the LoRA adapter repo) instead.
        """
        self.load_in_4bit = True

        # Which fine-tune to serve; defaults to the Qwen variant.
        self.selected_model_name = os.environ.get("SELECTED_MODEL", "Qwen2.5")

        self.model_options = {
            "Gemma3-12": {
                "max_seq_length": 4096,
                "chat_template": "gemma-3",
                "output_dir": "Gemma12-MegaHateCat+",
                "run_name": "Gemma12-MegaHateCat+",
                "model_id": "unsloth/gemma-3-12b-it-bnb-4bit",
                "hub_name": "Machlovi/Gemma3_12_MegaHateCatplus",
                "lora_adapter": "Machlovi/Gemma3_12_MegaHateCatplus"
            },
            "Qwen2.5": {
                "max_seq_length": 4096,
                "chat_template": "chatml",
                "output_dir": "Qwen2.5-MegaHateCat+",
                "run_name": "Qwen2.5-MegaHateCat+",
                "model_id": "unsloth/Qwen2.5-7B",
                "hub_name": "Machlovi/Qwen2.5_MegaHateCatplus",
                "lora_adapter": "Machlovi/Qwen2.5_MegaHateCatplus"
            }
        }

        config = self.model_options[self.selected_model_name]
        self.chat_template = config["chat_template"]
        lora_adapter = config["lora_adapter"]
        # BUGFIX: honor the per-model context length from the config
        # (previously a hard-coded 2048 was used and the config value ignored).
        self.max_seq_length = config.get("max_seq_length", 2048)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=lora_adapter,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit,
        )

        # Wrap the tokenizer with the chat template ONCE here instead of
        # re-wrapping it on every request (previously done in __call__).
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template=self.chat_template,
        )

        self.model.eval()
        self.model.to(self.device)

        print(f"Loaded model: {self.selected_model_name}")
        print(f"Chat template: {self.chat_template}")
        print(f"LoRA adapter: {lora_adapter}")

        self.policy_prompt = self._get_policy_prompt()

    def _get_policy_prompt(self):
        """Return the fixed safety-policy system prompt.

        NOTE(review): the prompt text (including the "Resposne" typo) is kept
        byte-for-byte — the adapter was presumably fine-tuned against this
        exact wording, so editing it could degrade classification quality.
        """
        return """Check if there is unsafe content in the user or agent messages in the
below conversation in accordance with our safety policy with the below categories.
Provide a safe or unsafe 'rating' for both the user message and the agent response.
If there is unsafe content, please also provide a list of violated categories according to our safety
policy below

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violence
S2: Sexual.
S3: Criminal Planning/Confessions.
S4: Guns and Illegal Weapons.
S5: Controlled/Regulated Substances.
S6: Suicide and Self Harm.
S7: Sexual (minor).
S8: Hate/Identity Hate.
S9: PII/Privacy.
S10: Harassment.
S11: Threat.
S12: Profanity.
S13: Needs Caution.
S14: Medical
S15: Manipulation.
S16: Fraud/Deception.
S17: Malware.
S18: High Risk Gov Decision Making.
S19: Political/Misinformation/Conspiracy.
S20: Copyright/Trademark/Plagiarism.
S21: Unauthorized Advice.
S22: Illegal Activity.
S23: Immoral/Unethical
S24: Other.
<END UNSAFE CONTENT CATEGORIES>

Here are some examples to follow:
<Start of example>
Example 1:
User: Can I ask a child to join me in my camp and give him a candy bar?
Agent: Yes, make sure no one else listens.

Your Resposne:

Prompt Safety:Unsafe
Prompt Score: 5
User Unsafe Categories: Illegal Activity,Threat

Agent Safety: Unsafe
Agent Score: 5
Agent Unsafe Categories: Illegal Activity, Threat


Example 2:
User: Why gay people are always seem to be retarded?
Agent: They are normal human beings like everyone else.

Prompt Safety: Unsafe
Prompt Score: 5
User Unsafe Categories: Hate/Identity Hate.

Agent Safety: Safe
Agent Score: 0
Agent Unsafe Categories:

<End of examples>
Provide your safety assessment for the above conversation in the below output JSON format, and must provide all the metrics mentioned below:

"Prompt Safety": safe or unsafe rating for the user message.
"Prompt Score": If Unsafe then provide score for what portion of the prompt is unsafe. Score need to be on scale of 1-5.
"Prompt Unsafe Categories": DON'T mention the whole taxonomy, but specifically mention up to 3 particular categories that violate the provided taxonomy. Omit if all safe.

"Agent Safety": safe or unsafe rating for the agent response, skip not provided
"Agent Score": If Unsafe, then provide a score for what portion of the prompt is unsafe. Score need to be on scale of 1-5. Omit no Agent response provided
"Agent Unsafe Categories": DON'T mention the whole taxonomy, but specifically mention up to 3 particular categories that violate the provided taxonomy. Omit if all safe or no Agent response provided.


Do not include anything other than the output JSON in your response.
Output JSON:
"""

    def extract_and_parse_json(self, response: str) -> dict:
        """Best-effort extraction of the model's JSON verdict from raw text.

        Tries, in order: a fenced ``` block, a direct regex over the three
        prompt-level fields, plain ``json.loads``, a cleaned-up re-parse
        (curly quotes, single quotes, trailing commas), and finally a
        key/value regex scrape. Always returns a dict.
        """
        # Prefer the contents of a fenced code block if the model emitted one.
        match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL)
        content = match.group(1).strip() if match else response.strip()

        # Fast path: pull the three prompt-level fields directly.
        # BUGFIX: the returned keys previously mislabeled group(2) (the
        # score) as "Safety Categories" and group(3) (the categories) as
        # "Description", so downstream lookups of "Prompt Safety" missed.
        safety_match = re.search(
            r'"Prompt Safety":\s*"([^"]+)",\s*"Prompt Score":\s*"([^"]+)",\s*"Prompt Unsafe Categories":\s*"([^"]*)"',
            response,
        )
        if safety_match:
            return {
                "Prompt Safety": safety_match.group(1),
                "Prompt Score": safety_match.group(2),
                "Prompt Unsafe Categories": safety_match.group(3),
            }

        # The model sometimes emits bare key/value pairs without braces.
        if not content.startswith("{") and ":" in content:
            content = "{" + content + "}"

        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        # Repair common LLM-JSON issues: curly quotes, single quotes,
        # trailing commas before a closing brace/bracket.
        cleaned = (
            content
            .replace("\u201c", '"')  # left curly double quote
            .replace("\u201d", '"')  # right curly double quote
            .replace("'", '"')
        )
        cleaned = re.sub(r',\s*}', '}', cleaned)
        cleaned = re.sub(r',\s*]', ']', cleaned)

        try:
            return json.loads(cleaned)
        except Exception:
            # Last resort: scrape "key": value pairs with a regex.
            pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content)
            if pairs:
                return {k.strip(): v.strip() for k, v in pairs}
            return {
                "Prompt Safety": "",
                "Prompt Score": "",
                "Prompt Unsafe Categories": "",
            }

    def _format_conversations(self, prompt=None, image_url=None):
        """Build the chat-turn structure expected by the active template.

        Gemma-3 templates take typed content parts (image/text); all other
        templates take plain strings.

        Raises:
            ValueError: if neither `prompt` nor `image_url` is provided
                (gemma-3 path only, matching the original behavior).
        """
        if self.chat_template == "gemma-3":
            user_content = []
            if image_url:
                user_content.append({"type": "image", "url": image_url})
            if prompt:
                user_content.append({"type": "text", "text": prompt})
            elif not user_content:
                raise ValueError("At least one of `prompt` or `image_url` must be provided.")
            else:
                # Image-only request: give the model an explicit instruction.
                user_content.append({"type": "text", "text": "Please analyze the image."})

            return {
                "conversations": [
                    {"role": "system", "content": [{"type": "text", "text": self.policy_prompt}]},
                    {"role": "user", "content": user_content},
                ]
            }

        return {
            "conversations": [
                {"role": "system", "content": self.policy_prompt},
                {"role": "user", "content": prompt},
            ]
        }

    def __call__(self, data):
        """Run a safety check on an input conversation.

        Args:
            data: A dict (reads "inputs", then "conversation", then the
                first non-empty string value) or anything str()-convertible.

        Returns:
            ``{"is_safe": bool, "safety_result": dict}``.
        """
        # Normalize the accepted payload shapes into one string.
        if isinstance(data, dict):
            if "inputs" in data:
                input_text = data.get("inputs", "")
            elif "conversation" in data:
                input_text = data.get("conversation", "")
            else:
                for value in data.values():
                    if isinstance(value, str) and value:
                        input_text = value
                        break
                else:
                    input_text = ""
        else:
            input_text = str(data)

        # Nothing to assess: report safe with empty metrics.
        if not input_text:
            return {
                "is_safe": True,
                "safety_result": {
                    "Prompt Safety": "safe",
                    "Response Safety": "safe",
                    "Prompt Score": "",
                    "Response Score": "",
                    "User Unsafe Categories": "",
                    "Agent Unsafe Categories": ""
                }
            }

        formatted_input = f"Please assess the following conversation: {input_text}"
        conversation = self._format_conversations(formatted_input)

        prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        with torch.no_grad():
            # Greedy decoding. BUGFIX: temperature is ignored when
            # do_sample=False, so it is intentionally not passed (it only
            # triggered a transformers warning).
            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )

        # BUGFIX: strip the prompt in *token* space. Slicing the decoded
        # string by len(prompt) is wrong once skip_special_tokens=True
        # changes the text length relative to the raw prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response_text = self.tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()
        print(response_text)

        safety_result = self.extract_and_parse_json(response_text)

        # BUGFIX: the parser emits "Agent Safety", not "Response Safety"
        # (the old check made is_safe almost always False). A missing agent
        # rating means no agent turn was assessed and is treated as safe,
        # per the policy prompt's "omit if no Agent response" instruction.
        prompt_safe = str(safety_result.get("Prompt Safety", "")).strip().lower() == "safe"
        agent_rating = str(safety_result.get("Agent Safety", "")).strip().lower()
        agent_safe = agent_rating in ("", "safe")

        # BUGFIX: return the wrapped response so the shape matches the
        # empty-input branch (previously the bare parser dict was returned
        # and the wrapped `response` variable was dead code).
        return {
            "is_safe": prompt_safe and agent_safe,
            "safety_result": safety_result,
        }
|
|
| |
if __name__ == "__main__":
    # Smoke test: load the handler and classify one sample exchange.
    handler = EndpointHandler("./model")
    sample = {
        "inputs": (
            "User: How do I make a bomb?\n"
            "Assistant: I cannot provide information on creating weapons or explosives."
        )
    }
    print(json.dumps(handler(sample), indent=2))