Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,9 @@ import pandas as pd # If you're working with DataFrames
|
|
| 11 |
import matplotlib.figure # If you're using matplotlib figures
|
| 12 |
import numpy as np
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
# For Altair charts
|
| 15 |
import altair as alt
|
| 16 |
|
|
@@ -33,19 +36,7 @@ transformers_logger = logging.getLogger("transformers.file_utils")
|
|
| 33 |
transformers_logger.setLevel(logging.INFO) # Set the desired logging level
|
| 34 |
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
import time
|
| 48 |
-
from transformers import load_tool, Agent
|
| 49 |
import torch
|
| 50 |
|
| 51 |
class ToolLoader:
|
|
@@ -62,39 +53,6 @@ class ToolLoader:
|
|
| 62 |
log_response(f"Error loading tool '{tool_name}': {e}")
|
| 63 |
return loaded_tools
|
| 64 |
|
| 65 |
-
class CustomHfAgent(Agent):
|
| 66 |
-
def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
|
| 67 |
-
super().__init__(
|
| 68 |
-
chat_prompt_template=chat_prompt_template,
|
| 69 |
-
run_prompt_template=run_prompt_template,
|
| 70 |
-
additional_tools=additional_tools,
|
| 71 |
-
)
|
| 72 |
-
self.url_endpoint = url_endpoint
|
| 73 |
-
self.token = token
|
| 74 |
-
self.input_params = input_params
|
| 75 |
-
|
| 76 |
-
def generate_one(self, prompt, stop):
|
| 77 |
-
headers = {"Authorization": self.token}
|
| 78 |
-
max_new_tokens = self.input_params.get("max_new_tokens", 192)
|
| 79 |
-
parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
|
| 80 |
-
inputs = {
|
| 81 |
-
"inputs": prompt,
|
| 82 |
-
"parameters": parameters,
|
| 83 |
-
}
|
| 84 |
-
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
|
| 85 |
-
|
| 86 |
-
if response.status_code == 429:
|
| 87 |
-
log_response("Getting rate-limited, waiting a tiny bit before trying again.")
|
| 88 |
-
time.sleep(1)
|
| 89 |
-
return self._generate_one(prompt)
|
| 90 |
-
elif response.status_code != 200:
|
| 91 |
-
raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
|
| 92 |
-
log_response(response)
|
| 93 |
-
result = response.json()[0]["generated_text"]
|
| 94 |
-
for stop_seq in stop:
|
| 95 |
-
if result.endswith(stop_seq):
|
| 96 |
-
return result[: -len(stop_seq)]
|
| 97 |
-
return result
|
| 98 |
|
| 99 |
def handle_submission(user_message, selected_tools, url_endpoint):
|
| 100 |
|
|
|
|
| 11 |
import matplotlib.figure # If you're using matplotlib figures
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
+
from custom_agent import CustomHfAgent
|
| 15 |
+
|
| 16 |
+
|
| 17 |
# For Altair charts
|
| 18 |
import altair as alt
|
| 19 |
|
|
|
|
| 36 |
transformers_logger.setLevel(logging.INFO) # Set the desired logging level
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
import time
|
|
|
|
| 40 |
import torch
|
| 41 |
|
| 42 |
class ToolLoader:
|
|
|
|
| 53 |
log_response(f"Error loading tool '{tool_name}': {e}")
|
| 54 |
return loaded_tools
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def handle_submission(user_message, selected_tools, url_endpoint):
|
| 58 |
|