Spaces:
Running
Running
File size: 7,782 Bytes
995d35f 2101fac 5fb65b7 2bdf9a9 8870d4b 5fb65b7 5d0417f f9314a9 995d35f 8870d4b 1ca9acb 2bdf9a9 5fb65b7 f1bceb8 f9314a9 5fb65b7 3f429e3 8870d4b b103d78 c7eddb5 5fb65b7 8870d4b 5fb65b7 e3dd34f 5fb65b7 f9314a9 e3dd34f 5fb65b7 a8f2af1 4137d9f a8f2af1 5fb65b7 f9314a9 5fb65b7 c7eddb5 5fb65b7 a8f2af1 c7eddb5 5fb65b7 e04da09 5fb65b7 c7eddb5 5fb65b7 c7eddb5 8870d4b 5fb65b7 c7eddb5 5fb65b7 8870d4b 5fb65b7 f9314a9 e3dd34f f9314a9 5fb65b7 e3dd34f 5fb65b7 e3dd34f 59ed4d3 e3dd34f c7eddb5 e3dd34f e04da09 5fb65b7 e3dd34f 5fb65b7 e3dd34f e04da09 5fb65b7 e04da09 2bdf9a9 5fb65b7 e04da09 5fb65b7 e04da09 5fb65b7 995d35f a8f2af1 2bdf9a9 5fb65b7 a8f2af1 8870d4b b103d78 8870d4b 1ca9acb b103d78 5fb65b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from fastapi import FastAPI, Header
import uvicorn
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import pygsheets
import json
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import SystemMessage, AnyMessage
# from langgraph.pregel import RetryPolicy
from langgraph.types import RetryPolicy
import json
from google.oauth2 import service_account
import os
from langchain_groq import ChatGroq
import groq
from datetime import datetime
from fastapi import HTTPException
from langchain_google_genai import ChatGoogleGenerativeAI
from opik.integrations.langchain import OpikTracer
from pytz import timezone
# Load environment variables - for local development
from dotenv import load_dotenv
load_dotenv()

# Credentials / endpoints, read from the environment (or a local .env file).
SHEET_URL = os.getenv("SHEET_URL")
# Service-account key as a JSON string; parsed with json.loads when writing to Sheets.
GOOGLESHEETS_CREDENTIALS = os.getenv("GOOGLESHEETS_CREDENTIALS")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Shared secret the /write_message endpoint compares against its `header` header.
HF_TOKEN = os.getenv("HF_TOKEN")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Model identifiers: overridable via environment, defaulting to the previously
# hard-coded values so existing deployments behave the same.
GROQ_MODEL = os.getenv("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-2.5-flash-lite")

# Timezone for log timestamps and the timestamp column written to the sheet.
ist_tz = timezone("Asia/Kolkata")
class TransactionParser(BaseModel):
    """This Pydantic class is used to parse the transaction message. The message is taken and the output is structured in a specific format based upon below definitions."""

    # NOTE(review): with_structured_output() presumably forwards the class docstring
    # and every Field description to the LLM as schema text, so edits to these
    # strings change the prompt. All fields are plain strings; Agent.write_message
    # copies them verbatim into the sheet.
    amount: str = Field(
        description="The amount of the transaction strictly in decimal format. Do not insert currency symbol.",
        # Pydantic v2 deprecates extra Field kwargs such as `example=`; passing it via
        # json_schema_extra emits the same "example" key in the JSON schema.
        json_schema_extra={"example": "123.45"},
    )
    dr_or_cr: str = Field(
        description="Identify if the transaction was debit (spent) or credit (received). Strictly choose one of the values - Debit or Credit"
    )
    receiver: str = Field(
        description="The recipient of the transaction. Identify the Merchant Name from the message text."
    )
    category: str = Field(
        description="The category of the transaction. The category of the transaction is linked to the Merchant Name. Strictly choose one of the values - Shopping,EMI,Education,Miscellaneous,Grocery,Utility,House Help,Travel,Transport,Food,Insurance"
    )
    transaction_origin: str = Field(
        description="The origin of the transaction. Provide the card or account number as well."
    )
class TransactionClassification(BaseModel):
    """This Pydantic class is used to classify the transaction message. The message is taken and the output is structured in a specific format based upon below definition."""

    # Routing label read by Agent.check_txn_and_decide: only the exact value
    # "Transaction" sends the message on to parsing and writing.
    # NOTE(review): the docstring and description are presumably sent to the LLM
    # as schema text by with_structured_output(); keep them verbatim.
    classification: Annotated[
        str,
        Field(
            description="Classification of the transaction. Strictly choose one of the values - Transaction, OTP, Scheduled, Reminder, Reward_Points_Credit, sweep_in_fd"
        ),
    ]
# Graph state with a single "messages" channel. The operator.add reducer makes the
# list each node returns get concatenated onto the accumulated messages rather
# than replacing them, so messages[0] remains the caller's original input.
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[list[AnyMessage], operator.add]},
)
class Agent:
    """Transaction-processing pipeline compiled as a LangGraph state graph.

    Flow: ``classify_txn_type`` -> (only if classified as "Transaction")
    ``parse_message`` -> ``write_message`` -> END. Any other classification ends
    the run right after ``classify_txn_type``.

    ``state["messages"][0]`` is the raw transaction text supplied by the caller;
    every node appends its own output to the list.
    """

    def __init__(self, model, system=""):
        """Build and compile the graph.

        Args:
            model: LangChain chat model used by both LLM nodes; it must support
                ``with_structured_output``.
            system: Optional system prompt prepended for the classification call.
        """
        self.system = system
        self.model = model
        # The LLM nodes are retried only on Groq connection errors (up to 5 attempts).
        llm_retry = RetryPolicy(retry_on=[groq.APIConnectionError], max_attempts=5)

        graph = StateGraph(AgentState)
        graph.add_node("classify_txn_type", self.classify_txn_type, retry=llm_retry)
        graph.add_node("parse_message", self.parse_message, retry=llm_retry)
        graph.add_node("write_message", self.write_message)
        graph.add_conditional_edges(
            "classify_txn_type",
            self.check_txn_and_decide,
            {True: "parse_message", False: END},
        )
        graph.add_edge("parse_message", "write_message")
        graph.add_edge("write_message", END)
        graph.set_entry_point("classify_txn_type")
        self.graph = graph.compile()

    def classify_txn_type(self, state: AgentState) -> AgentState:
        """Label the incoming message with a TransactionClassification."""
        print(f"{datetime.now(ist_tz)}: Classifying transaction type...")
        messages = state["messages"]
        if self.system:
            messages = [SystemMessage(content=self.system)] + messages
        classification = self.model.with_structured_output(TransactionClassification).invoke(messages)
        print(f"{datetime.now(ist_tz)}: Classifying transaction type completed.")
        return {"messages": [classification]}

    def parse_message(self, state: AgentState) -> AgentState:
        """Extract the structured TransactionParser fields from the raw message."""
        print(f"{datetime.now(ist_tz)}: Parsing transaction message...")
        # messages[0] is the caller's raw text; later entries are node outputs.
        raw_message = state["messages"][0]
        system = "You are a helpful assistant skilled at parsing transaction messages and providing structured responses."
        human = "Categorize the transaction message and provide the output in a structured format: {topic}"
        prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
        chain = prompt | self.model.with_structured_output(TransactionParser)
        result = chain.invoke({"topic": raw_message})
        print(f"{datetime.now(ist_tz)}: Parsing transaction message completed.")
        return {"messages": [result]}

    @staticmethod
    def _open_first_worksheet():
        """Authorize with the service-account JSON and return SHEET_URL's first worksheet."""
        scopes = (
            "https://www.googleapis.com/auth/spreadsheets",
            "https://www.googleapis.com/auth/drive",
        )
        service_account_info = json.loads(GOOGLESHEETS_CREDENTIALS)  # type: ignore
        credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=scopes)
        client = pygsheets.authorize(custom_credentials=credentials)
        return client.open_by_url(SHEET_URL)[0]

    def write_message(self, state: AgentState) -> AgentState:
        """Append one row (columns A-G) describing the parsed transaction to the sheet.

        Columns: amount, Debit/Credit, receiver, category, timestamp (IST),
        transaction origin, raw message.
        """
        print(f"{datetime.now(ist_tz)}: Writing transaction message to Google Sheets...")
        parsed = state["messages"][-1]
        row = [
            parsed.amount,  # type: ignore
            parsed.dr_or_cr,  # type: ignore
            parsed.receiver,  # type: ignore
            parsed.category,  # type: ignore
            datetime.now(ist_tz).strftime("%Y-%m-%d %H:%M:%S"),
            parsed.transaction_origin,  # type: ignore
            state["messages"][0],
        ]
        worksheet = self._open_first_worksheet()
        # One append request after the table that starts at A1. This replaces the old
        # get_as_df(end='G999') row count, which stopped growing past row 999 (so new
        # rows kept overwriting the same row), and the seven separate cell writes,
        # which could leave a partially written row. overwrite=False makes the API
        # insert new rows instead of clobbering anything below the table.
        worksheet.append_table(values=[row], start="A1", dimension="ROWS", overwrite=False)
        print(f"{datetime.now(ist_tz)}: Writing transaction message to Google Sheets completed.")
        return {"messages": ["Transaction Completed"]}  # type: ignore

    def check_txn_and_decide(self, state: AgentState):
        """Route after classification: True -> parse_message, False -> END.

        Raises:
            HTTPException: status 400 when the last message has no
                ``classification`` attribute (presumably the model produced no
                parseable structured output, e.g. ``None``).
        """
        try:
            classification = state["messages"][-1].classification  # type: ignore
        except AttributeError as err:
            # Attribute access raises AttributeError, not json.JSONDecodeError,
            # so the previous handler could never fire.
            raise HTTPException(status_code=400, detail="Invalid response format from model") from err
        return classification == "Transaction"
app = FastAPI()


@app.get("/")
def greetings():
    """Root route: return a short usage hint for the bot's API."""
    usage_hint = (
        "Hello, this is a transaction bot. "
        "Please send a POST request to /write_message with the transaction data."
    )
    return {"message": usage_hint}
@app.post("/write_message")
def write_message(data: dict, header: str = Header()):
    """Classify, parse and log a transaction message posted as ``{"message": "..."}``.

    The request must carry the shared secret (env ``HF_TOKEN``) in the ``header``
    HTTP header. The Gemini model is tried first; if anything in that run raises,
    the whole pipeline is re-run once with the Groq model as a fallback.

    Raises:
        HTTPException: status 400 for a wrong/missing secret or a missing/empty
            ``message``; the fallback run's errors propagate unchanged.
    """
    import secrets  # constant-time comparison for the shared secret

    # Compare as bytes: compare_digest raises TypeError on non-ASCII str arguments.
    if HF_TOKEN is None or not secrets.compare_digest(header.encode(), HF_TOKEN.encode()):
        raise HTTPException(status_code=400, detail="Invalid header")

    message = data.get("message")
    if not isinstance(message, str) or not message.strip():
        # Previously a missing key surfaced as a 500 KeyError.
        raise HTTPException(status_code=400, detail="Request body must contain a non-empty 'message' string")

    prompt = """You are a smart assistant adept at working with bank transaction messages."""

    def _run_pipeline(model):
        """Build an Agent around `model` and run the graph on the message."""
        Agent(model, system=prompt).graph.invoke({"messages": [message]})

    try:
        _run_pipeline(ChatGoogleGenerativeAI(model=GOOGLE_MODEL, callbacks=[OpikTracer()]))
    except Exception as err:  # fallback model
        # Record why the primary model failed instead of silently discarding it.
        print(f"{datetime.now(ist_tz)}: Primary model failed ({err!r}); retrying with Groq fallback.")
        _run_pipeline(ChatGroq(model=GROQ_MODEL, temperature=1, callbacks=[OpikTracer()]))
    return {"message": "Transaction completed successfully"}
if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces (0.0.0.0), port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )