File size: 3,963 Bytes
1f725d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import logging
import sys
import json
import re
from src.Blog.models.RouterDecision_model import RouterDecision
from langchain_core.messages import SystemMessage, HumanMessage
from src.Blog.models.State_model import State
from src.Blog.llm import llm
from src.Blog.prompts import ROUTER_SYSTEM
from exception import MyException
from utils.asyncHandler import asyncHandler


def _message_text(content) -> str:
    """Flatten an LLM message ``content`` payload into a single string.

    Plain strings are returned unchanged. A list payload is reduced to the
    concatenation of its string items and of the ``"text"`` field of its dict
    items; any other block is skipped. Anything else falls back to ``str()``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                pieces.append(block["text"])
        return "".join(pieces)
    return str(content)


def _extract_json_candidate(text: str) -> str:
    """Return the substring of *text* most likely to hold a JSON object.

    A fenced ```json block wins when present; otherwise the span from the
    first ``{`` to the last ``}`` is used. Returns ``""`` when neither exists.
    """
    fenced = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    if fenced:
        return fenced.group(1)

    start = text.find('{')
    end = text.rfind('}')
    if start != -1 and end != -1:
        return text[start:end + 1]
    return ""


def _load_json_object(candidate: str) -> dict:
    """Parse *candidate* into a dict, tolerating junk after the object.

    If the whole string is not valid JSON, progressively shorter prefixes are
    tried, each one ending ON a closing brace (the brace is kept, otherwise no
    prefix could ever be a complete object).

    Raises:
        ValueError: if no prefix parses, or the parsed value is not an object.
    """
    try:
        data = json.loads(candidate)
    except json.JSONDecodeError:
        brace = candidate.rfind('}')
        while brace != -1:
            try:
                data = json.loads(candidate[:brace + 1])
                break
            except json.JSONDecodeError:
                # Move to the previous closing brace and try a shorter prefix.
                brace = candidate.rfind('}', 0, brace)
        else:
            raise ValueError("Failed to parse JSON even after structural truncation")

    if not isinstance(data, dict):
        raise ValueError("LLM returned JSON that is not an object")
    return data


@asyncHandler
async def router_node(state: State):
    """Ask the LLM whether the topic needs research and how to proceed.

    First tries ``llm.with_structured_output(RouterDecision)``. If that raises
    or yields a falsy result, the LLM is invoked again with a stricter
    JSON-only instruction and the reply is parsed manually.

    Args:
        state: Graph state; ``state['topic']`` is read.

    Returns:
        A dict with the keys ``needs_research``, ``mode`` and ``queries``
        taken from the resulting ``RouterDecision``.

    Raises:
        ValueError: if the fallback reply contains no parseable JSON object.
        Exception: any other error from the fallback path is logged and
            re-raised unchanged.
    """
    logging.info("Entering router_node")
    topic = state['topic']
    logging.debug(f"Topic: {topic}")

    try:
        try:
            runnable = llm.with_structured_output(RouterDecision)
            decision = await runnable.ainvoke(
                [
                    SystemMessage(content=ROUTER_SYSTEM),
                    HumanMessage(content=f"Topic: {topic}")
                ]
            )
            if decision:
                logging.info(f"Router decision (structured): needs_research={decision.needs_research}, mode={decision.mode}")
                return {
                    "needs_research": decision.needs_research,
                    "mode": decision.mode,
                    "queries": decision.queries,
                }
        except Exception as e:
            # Deliberate best-effort: any structured-output failure falls
            # through to the manual JSON parse below.
            logging.warning(f"Structured output failed: {str(e)}. Attempting manual parse.")

        raw_response = await llm.ainvoke(
            [
                SystemMessage(content=ROUTER_SYSTEM + "\n\nCRITICAL: You MUST return a valid JSON object. Do not include any text before or after the JSON."),
                HumanMessage(content=f"Topic: {topic}")
            ]
        )
        # Message content may be a list of content blocks rather than a str;
        # normalise it before running string/regex operations on it.
        content = _message_text(raw_response.content)
        logging.debug(f"Raw LLM content for fallback: {content}")

        json_str = _extract_json_candidate(content)
        if not json_str:
            logging.error("Failed to extract JSON from LLM response")
            raise ValueError("LLM failed to return a valid RouterDecision. Please check prompts or model output.")

        data = _load_json_object(json_str)

        # Accept booleans as well as common string/number spellings of "true".
        needs_res = str(data.get("needs_research", "")).lower() in ("true", "1", "yes")

        decision = RouterDecision(
            needs_research=needs_res,
            # `or` also covers an explicit JSON null, which .get() would pass through.
            mode=data.get("mode") or "open_book",
            queries=data.get("queries") or [],
        )
        logging.info(f"Router decision (manual): needs_research={decision.needs_research}, mode={decision.mode}")
        return {
            "needs_research": decision.needs_research,
            "mode": decision.mode,
            "queries": decision.queries,
        }

    except Exception as e:
        logging.error(f"Error in router_node: {str(e)}")
        raise

def route_next(state: State) -> str:
    """Choose the node that follows the router.

    Reads ``needs_research`` from *state* (treated as ``False`` when absent),
    logs the value, and returns ``"research"`` when it is truthy, otherwise
    ``"orchestrator"``.
    """
    research_flag = state.get("needs_research", False)
    logging.info(f"Routing next based on research need: {research_flag}")
    if research_flag:
        return "research"
    return "orchestrator"