bitliu committed on
Commit
8208a85
·
1 Parent(s): 9d36b35

Signed-off-by: bitliu <[email protected]>

Files changed (1) hide show
  1. app.py +189 -37
app.py CHANGED
@@ -1,7 +1,11 @@
1
  import streamlit as st
2
  import streamlit.components.v1 as components
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
 
 
 
 
5
 
6
  # ============== Model Configurations ==============
7
  MODELS = {
@@ -10,11 +14,20 @@ MODELS = {
10
  "description": "Classifies prompts into academic/professional categories.",
11
  "type": "sequence",
12
  "labels": {
13
- 0: ("biology", "🧬"), 1: ("business", "💼"), 2: ("chemistry", "🧪"),
14
- 3: ("computer science", "💻"), 4: ("economics", "📈"), 5: ("engineering", "⚙️"),
15
- 6: ("health", "🏥"), 7: ("history", "📜"), 8: ("law", "⚖️"),
16
- 9: ("math", "🔢"), 10: ("other", "📦"), 11: ("philosophy", "🤔"),
17
- 12: ("physics", "⚛️"), 13: ("psychology", "🧠"),
 
 
 
 
 
 
 
 
 
18
  },
19
  "demo": "What is photosynthesis and how does it work?",
20
  },
@@ -37,12 +50,24 @@ MODELS = {
37
  "description": "Detects the primary type of PII in the text.",
38
  "type": "sequence",
39
  "labels": {
40
- 0: ("AGE", "🎂"), 1: ("CREDIT_CARD", "💳"), 2: ("DATE_TIME", "📅"),
41
- 3: ("DOMAIN_NAME", "🌐"), 4: ("EMAIL_ADDRESS", "📧"), 5: ("GPE", "🗺️"),
42
- 6: ("IBAN_CODE", "🏦"), 7: ("IP_ADDRESS", "🖥️"), 8: ("NO_PII", "✅"),
43
- 9: ("NRP", "👥"), 10: ("ORGANIZATION", "🏢"), 11: ("PERSON", "👤"),
44
- 12: ("PHONE_NUMBER", "📞"), 13: ("STREET_ADDRESS", "🏠"), 14: ("TITLE", "📛"),
45
- 15: ("US_DRIVER_LICENSE", "🚗"), 16: ("US_SSN", "🔐"), 17: ("ZIP_CODE", "📮"),
 
 
 
 
 
 
 
 
 
 
 
 
46
  },
47
  "demo": "My email is [email protected] and my phone is 555-123-4567",
48
  },
@@ -53,6 +78,32 @@ MODELS = {
53
  "labels": None,
54
  "demo": "John Smith works at Microsoft in Seattle, his email is [email protected]",
55
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  }
57
 
58
 
@@ -78,7 +129,29 @@ def classify_sequence(text: str, model_id: str, labels: dict) -> tuple:
78
  pred_class = torch.argmax(probs).item()
79
  label_name, emoji = labels[pred_class]
80
  confidence = probs[pred_class].item()
81
- all_scores = {f"{labels[i][1]} {labels[i][0]}": float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  return label_name, emoji, confidence, all_scores
83
 
84
 
@@ -86,7 +159,13 @@ def classify_tokens(text: str, model_id: str) -> list:
86
  """Token-level NER classification."""
87
  tokenizer, model = load_model(model_id, "token")
88
  id2label = model.config.id2label
89
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, return_offsets_mapping=True)
 
 
 
 
 
 
90
  offset_mapping = inputs.pop("offset_mapping")[0].tolist()
91
  with torch.no_grad():
92
  outputs = model(**inputs)
@@ -101,7 +180,11 @@ def classify_tokens(text: str, model_id: str) -> list:
101
  if current_entity:
102
  entities.append(current_entity)
103
  current_entity = {"type": label[2:], "start": start, "end": end}
104
- elif label.startswith("I-") and current_entity and label[2:] == current_entity["type"]:
 
 
 
 
105
  current_entity["end"] = end
106
  else:
107
  if current_entity:
@@ -110,7 +193,7 @@ def classify_tokens(text: str, model_id: str) -> list:
110
  if current_entity:
111
  entities.append(current_entity)
112
  for e in entities:
113
- e["text"] = text[e["start"]:e["end"]]
114
  return entities
115
 
116
 
@@ -119,13 +202,21 @@ def create_highlighted_html(text: str, entities: list) -> str:
119
  if not entities:
120
  return f'<div style="padding:15px;background:#f0f0f0;border-radius:8px;">{text}</div>'
121
  html = text
122
- colors = {"EMAIL_ADDRESS": "#ff6b6b", "PHONE_NUMBER": "#4ecdc4", "PERSON": "#45b7d1",
123
- "STREET_ADDRESS": "#96ceb4", "US_SSN": "#d63384", "CREDIT_CARD": "#fd7e14",
124
- "ORGANIZATION": "#6f42c1", "GPE": "#20c997", "IP_ADDRESS": "#0dcaf0"}
 
 
 
 
 
 
 
 
125
  for e in sorted(entities, key=lambda x: x["start"], reverse=True):
126
  color = colors.get(e["type"], "#ffc107")
127
  span = f'<span style="background:{color};padding:2px 6px;border-radius:4px;color:white;" title="{e["type"]}">{e["text"]}</span>'
128
- html = html[:e["start"]] + span + html[e["end"]:]
129
  return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{html}</div>'
130
 
131
 
@@ -135,10 +226,15 @@ def main():
135
  # Header with logo
136
  col1, col2 = st.columns([1, 4])
137
  with col1:
138
- st.image("https://github.com/vllm-project/semantic-router/blob/main/website/static/img/vllm.png?raw=true", width=150)
 
 
 
139
  with col2:
140
  st.title("🧠 LLM Semantic Router")
141
- st.markdown("**Intelligent Router for Mixture-of-Models** | Part of the [vLLM](https://github.com/vllm-project/vllm) ecosystem")
 
 
142
 
143
  st.markdown("---")
144
 
@@ -161,18 +257,70 @@ def main():
161
 
162
  # Main content
163
  st.subheader("📝 Input")
164
- text_input = st.text_area(
165
- "Enter text to analyze:",
166
- value=model_config["demo"],
167
- height=120,
168
- placeholder="Type your text here..."
169
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  st.markdown("---")
172
 
173
  # Analyze button
174
  if st.button("🔍 Analyze", type="primary", use_container_width=True):
175
- if not text_input.strip():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  st.warning("Please enter some text to analyze.")
177
  else:
178
  with st.spinner("Analyzing..."):
@@ -185,14 +333,14 @@ def main():
185
  "label": label,
186
  "emoji": emoji,
187
  "confidence": conf,
188
- "scores": scores
189
  }
190
  else:
191
  entities = classify_tokens(text_input, model_config["id"])
192
  st.session_state.result = {
193
  "type": "token",
194
  "entities": entities,
195
- "text": text_input
196
  }
197
 
198
  # Display results
@@ -200,24 +348,28 @@ def main():
200
  st.markdown("---")
201
  st.subheader("📊 Results")
202
  result = st.session_state.result
203
- if result["type"] == "sequence":
204
  col1, col2 = st.columns([1, 1])
205
  with col1:
206
  st.success(f"{result['emoji']} **{result['label']}**")
207
  st.metric("Confidence", f"{result['confidence']:.1%}")
208
  with col2:
209
  st.markdown("**All Scores:**")
210
- sorted_scores = dict(sorted(result["scores"].items(), key=lambda x: x[1], reverse=True))
 
 
211
  for k, v in sorted_scores.items():
212
  st.progress(v, text=f"{k}: {v:.1%}")
213
- else:
214
  entities = result["entities"]
215
  if entities:
216
  st.success(f"Found {len(entities)} PII entity(s)")
217
  for e in entities:
218
  st.markdown(f"- **{e['type']}**: `{e['text']}`")
219
  st.markdown("### Highlighted Text")
220
- components.html(create_highlighted_html(result["text"], entities), height=150)
 
 
221
  else:
222
  st.info("✅ No PII detected")
223
 
@@ -235,9 +387,9 @@ def main():
235
  <b>GitHub</b>: <a href="https://github.com/vllm-project/semantic-router">vllm-project/semantic-router</a>
236
  </div>
237
  """,
238
- unsafe_allow_html=True
239
  )
240
 
241
 
242
  if __name__ == "__main__":
243
- main()
 
1
  import streamlit as st
2
  import streamlit.components.v1 as components
3
  import torch
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification,
7
+ AutoModelForTokenClassification,
8
+ )
9
 
10
  # ============== Model Configurations ==============
11
  MODELS = {
 
14
  "description": "Classifies prompts into academic/professional categories.",
15
  "type": "sequence",
16
  "labels": {
17
+ 0: ("biology", "🧬"),
18
+ 1: ("business", "💼"),
19
+ 2: ("chemistry", "🧪"),
20
+ 3: ("computer science", "💻"),
21
+ 4: ("economics", "📈"),
22
+ 5: ("engineering", "⚙️"),
23
+ 6: ("health", "🏥"),
24
+ 7: ("history", "📜"),
25
+ 8: ("law", "⚖️"),
26
+ 9: ("math", "🔢"),
27
+ 10: ("other", "📦"),
28
+ 11: ("philosophy", "🤔"),
29
+ 12: ("physics", "⚛️"),
30
+ 13: ("psychology", "🧠"),
31
  },
32
  "demo": "What is photosynthesis and how does it work?",
33
  },
 
50
  "description": "Detects the primary type of PII in the text.",
51
  "type": "sequence",
52
  "labels": {
53
+ 0: ("AGE", "🎂"),
54
+ 1: ("CREDIT_CARD", "💳"),
55
+ 2: ("DATE_TIME", "📅"),
56
+ 3: ("DOMAIN_NAME", "🌐"),
57
+ 4: ("EMAIL_ADDRESS", "📧"),
58
+ 5: ("GPE", "🗺️"),
59
+ 6: ("IBAN_CODE", "🏦"),
60
+ 7: ("IP_ADDRESS", "🖥️"),
61
+ 8: ("NO_PII", "✅"),
62
+ 9: ("NRP", "👥"),
63
+ 10: ("ORGANIZATION", "🏢"),
64
+ 11: ("PERSON", "👤"),
65
+ 12: ("PHONE_NUMBER", "📞"),
66
+ 13: ("STREET_ADDRESS", "🏠"),
67
+ 14: ("TITLE", "📛"),
68
+ 15: ("US_DRIVER_LICENSE", "🚗"),
69
+ 16: ("US_SSN", "🔐"),
70
+ 17: ("ZIP_CODE", "📮"),
71
  },
72
  "demo": "My email is [email protected] and my phone is 555-123-4567",
73
  },
 
78
  "labels": None,
79
  "demo": "John Smith works at Microsoft in Seattle, his email is [email protected]",
80
  },
81
+ "😤 Dissatisfaction Detector": {
82
+ "id": "llm-semantic-router/dissat-detector",
83
+ "description": "Detects user dissatisfaction in conversational AI interactions. Classifies user follow-up messages as satisfied (SAT) or dissatisfied (DISSAT).",
84
+ "type": "dialogue",
85
+ "labels": {0: ("SAT", "🟢"), 1: ("DISSAT", "🔴")},
86
+ "demo": {
87
+ "query": "Find a restaurant nearby",
88
+ "response": "I found Italian Kitchen for you.",
89
+ "followup": "Show me other options",
90
+ },
91
+ },
92
+ "🔍 Dissatisfaction Explainer": {
93
+ "id": "llm-semantic-router/dissat-explainer",
94
+ "description": "Explains why a user is dissatisfied. Stage 2 of hierarchical dissatisfaction detection - classifies into NEED_CLARIFICATION, WRONG_ANSWER, or WANT_DIFFERENT.",
95
+ "type": "dialogue",
96
+ "labels": {
97
+ 0: ("NEED_CLARIFICATION", "❓"),
98
+ 1: ("WRONG_ANSWER", "❌"),
99
+ 2: ("WANT_DIFFERENT", "🔄"),
100
+ },
101
+ "demo": {
102
+ "query": "Book a table for 2",
103
+ "response": "Table for 3 confirmed",
104
+ "followup": "No, I said 2 people not 3",
105
+ },
106
+ },
107
  }
108
 
109
 
 
129
  pred_class = torch.argmax(probs).item()
130
  label_name, emoji = labels[pred_class]
131
  confidence = probs[pred_class].item()
132
+ all_scores = {
133
+ f"{labels[i][1]} {labels[i][0]}": float(probs[i]) for i in range(len(labels))
134
+ }
135
+ return label_name, emoji, confidence, all_scores
136
+
137
+
138
def classify_dialogue(
    query: str, response: str, followup: str, model_id: str, labels: dict
) -> tuple:
    """Classify a three-turn dialogue with a sequence-classification model.

    The three turns are joined into a single tagged prompt, run through the
    tokenizer/model pair returned by ``load_model(model_id, "sequence")``,
    and the softmax over the logits is mapped onto ``labels``.

    Args:
        query: The original user query.
        response: The system's response to that query.
        followup: The user's follow-up message.
        model_id: Model identifier forwarded to ``load_model``.
        labels: Mapping of class index to a ``(name, emoji)`` pair.

    Returns:
        A ``(label_name, emoji, confidence, all_scores)`` tuple, where
        ``all_scores`` maps ``"<emoji> <name>"`` to the probability of every
        class listed in ``labels``.

    Raises:
        KeyError: If the predicted class index has no entry in ``labels``.
    """
    tokenizer, model = load_model(model_id, "sequence")
    # Format input as per model requirements
    text = f"[USER QUERY] {query}\n[SYSTEM RESPONSE] {response}\n[USER FOLLOWUP] {followup}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Single input, so take row 0 of the (batch, classes) probabilities.
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    pred_class = torch.argmax(probs).item()
    label_name, emoji = labels[pred_class]
    confidence = probs[pred_class].item()
    # Walk the mapping itself rather than range(len(labels)) so the score
    # table does not depend on the keys being exactly 0..n-1.
    all_scores = {
        f"{label_emoji} {name}": float(probs[index])
        for index, (name, label_emoji) in labels.items()
    }
    return label_name, emoji, confidence, all_scores
156
 
157
 
 
159
  """Token-level NER classification."""
160
  tokenizer, model = load_model(model_id, "token")
161
  id2label = model.config.id2label
162
+ inputs = tokenizer(
163
+ text,
164
+ return_tensors="pt",
165
+ truncation=True,
166
+ max_length=512,
167
+ return_offsets_mapping=True,
168
+ )
169
  offset_mapping = inputs.pop("offset_mapping")[0].tolist()
170
  with torch.no_grad():
171
  outputs = model(**inputs)
 
180
  if current_entity:
181
  entities.append(current_entity)
182
  current_entity = {"type": label[2:], "start": start, "end": end}
183
+ elif (
184
+ label.startswith("I-")
185
+ and current_entity
186
+ and label[2:] == current_entity["type"]
187
+ ):
188
  current_entity["end"] = end
189
  else:
190
  if current_entity:
 
193
  if current_entity:
194
  entities.append(current_entity)
195
  for e in entities:
196
+ e["text"] = text[e["start"] : e["end"]]
197
  return entities
198
 
199
 
 
202
  if not entities:
203
  return f'<div style="padding:15px;background:#f0f0f0;border-radius:8px;">{text}</div>'
204
  html = text
205
+ colors = {
206
+ "EMAIL_ADDRESS": "#ff6b6b",
207
+ "PHONE_NUMBER": "#4ecdc4",
208
+ "PERSON": "#45b7d1",
209
+ "STREET_ADDRESS": "#96ceb4",
210
+ "US_SSN": "#d63384",
211
+ "CREDIT_CARD": "#fd7e14",
212
+ "ORGANIZATION": "#6f42c1",
213
+ "GPE": "#20c997",
214
+ "IP_ADDRESS": "#0dcaf0",
215
+ }
216
  for e in sorted(entities, key=lambda x: x["start"], reverse=True):
217
  color = colors.get(e["type"], "#ffc107")
218
  span = f'<span style="background:{color};padding:2px 6px;border-radius:4px;color:white;" title="{e["type"]}">{e["text"]}</span>'
219
+ html = html[: e["start"]] + span + html[e["end"] :]
220
  return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{html}</div>'
221
 
222
 
 
226
  # Header with logo
227
  col1, col2 = st.columns([1, 4])
228
  with col1:
229
+ st.image(
230
+ "https://github.com/vllm-project/semantic-router/blob/main/website/static/img/vllm.png?raw=true",
231
+ width=150,
232
+ )
233
  with col2:
234
  st.title("🧠 LLM Semantic Router")
235
+ st.markdown(
236
+ "**Intelligent Router for Mixture-of-Models** | Part of the [vLLM](https://github.com/vllm-project/vllm) ecosystem"
237
+ )
238
 
239
  st.markdown("---")
240
 
 
257
 
258
  # Main content
259
  st.subheader("📝 Input")
260
+
261
+ # Different input UI based on model type
262
+ if model_config["type"] == "dialogue":
263
+ # Dialogue models need query, response, and followup
264
+ demo = model_config["demo"]
265
+ query_input = st.text_input(
266
+ "🗣️ User Query:",
267
+ value=demo["query"],
268
+ placeholder="Enter the original user query...",
269
+ )
270
+ response_input = st.text_input(
271
+ "🤖 System Response:",
272
+ value=demo["response"],
273
+ placeholder="Enter the system's response...",
274
+ )
275
+ followup_input = st.text_input(
276
+ "💬 User Follow-up:",
277
+ value=demo["followup"],
278
+ placeholder="Enter the user's follow-up message...",
279
+ )
280
+ text_input = None # Not used for dialogue models
281
+ else:
282
+ # Standard text input for other models
283
+ text_input = st.text_area(
284
+ "Enter text to analyze:",
285
+ value=model_config["demo"],
286
+ height=120,
287
+ placeholder="Type your text here...",
288
+ )
289
+ query_input = response_input = followup_input = None
290
 
291
  st.markdown("---")
292
 
293
  # Analyze button
294
  if st.button("🔍 Analyze", type="primary", use_container_width=True):
295
+ if model_config["type"] == "dialogue":
296
+ if (
297
+ not query_input.strip()
298
+ or not response_input.strip()
299
+ or not followup_input.strip()
300
+ ):
301
+ st.warning("Please fill in all dialogue fields.")
302
+ else:
303
+ with st.spinner("Analyzing..."):
304
+ label, emoji, conf, scores = classify_dialogue(
305
+ query_input,
306
+ response_input,
307
+ followup_input,
308
+ model_config["id"],
309
+ model_config["labels"],
310
+ )
311
+ st.session_state.result = {
312
+ "type": "dialogue",
313
+ "label": label,
314
+ "emoji": emoji,
315
+ "confidence": conf,
316
+ "scores": scores,
317
+ "input": {
318
+ "query": query_input,
319
+ "response": response_input,
320
+ "followup": followup_input,
321
+ },
322
+ }
323
+ elif not text_input.strip():
324
  st.warning("Please enter some text to analyze.")
325
  else:
326
  with st.spinner("Analyzing..."):
 
333
  "label": label,
334
  "emoji": emoji,
335
  "confidence": conf,
336
+ "scores": scores,
337
  }
338
  else:
339
  entities = classify_tokens(text_input, model_config["id"])
340
  st.session_state.result = {
341
  "type": "token",
342
  "entities": entities,
343
+ "text": text_input,
344
  }
345
 
346
  # Display results
 
348
  st.markdown("---")
349
  st.subheader("📊 Results")
350
  result = st.session_state.result
351
+ if result["type"] in ("sequence", "dialogue"):
352
  col1, col2 = st.columns([1, 1])
353
  with col1:
354
  st.success(f"{result['emoji']} **{result['label']}**")
355
  st.metric("Confidence", f"{result['confidence']:.1%}")
356
  with col2:
357
  st.markdown("**All Scores:**")
358
+ sorted_scores = dict(
359
+ sorted(result["scores"].items(), key=lambda x: x[1], reverse=True)
360
+ )
361
  for k, v in sorted_scores.items():
362
  st.progress(v, text=f"{k}: {v:.1%}")
363
+ elif result["type"] == "token":
364
  entities = result["entities"]
365
  if entities:
366
  st.success(f"Found {len(entities)} PII entity(s)")
367
  for e in entities:
368
  st.markdown(f"- **{e['type']}**: `{e['text']}`")
369
  st.markdown("### Highlighted Text")
370
+ components.html(
371
+ create_highlighted_html(result["text"], entities), height=150
372
+ )
373
  else:
374
  st.info("✅ No PII detected")
375
 
 
387
  <b>GitHub</b>: <a href="https://github.com/vllm-project/semantic-router">vllm-project/semantic-router</a>
388
  </div>
389
  """,
390
+ unsafe_allow_html=True,
391
  )
392
 
393
 
394
  if __name__ == "__main__":
395
+ main()