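"""FastAPI backend for a knowledge-graph question-answering service.

Pipeline: detect and translate the query language, retrieve matching
knowledge-graph triplets and news passages from prebuilt FAISS indexes,
summarize each source with an LLM via the Together API, then merge both
summaries into the final answer returned by /api/query.
"""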
import sqlite3
import contextlib
import json
import os
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from together import Together

app = FastAPI(title="Knowledge Graph API")

# Enable CORS for frontend access
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Database configuration - UPDATE THESE PATHS
DATABASE_CONFIG = {
    "triplets_db": "triplets_new.db",  
    "definitions_db": "relations_new.db",
    "news_db": "cnnhealthnews2.db",
    "triplets_table": "triplets", 
    "definitions_table": "relations", 
    "head_column": "head_entity",
    "relation_column": "relation", 
    "tail_column": "tail_entity",
    "definition_column": "definition",
    "link_column": "link",
    "title_column": "column",
    "content_column": "content"
}
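
# Schema assumed by the queries below: triplets(head_entity, relation,
# tail_entity), relations(relation, definition), and
# CNNHEALTHNEWS2(link, title, content).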

class GraphNode(BaseModel):
    id: str
    label: str
    type: str = "entity"

class GraphEdge(BaseModel):
    source: str
    target: str
    relation: str
    definition: Optional[str] = None

class GraphData(BaseModel):
    nodes: List[GraphNode]
    edges: List[GraphEdge]

class TripletData(BaseModel):
    head: str
    relation: str
    tail: str

class RelationDefinition(BaseModel):
    relation: str
    definition: str

class RetrieveTripletsResponse(BaseModel):
    triplets: List[TripletData]
    relations: List[RelationDefinition]

class NewsItem(BaseModel):
    url: str
    content: str
    preview: str
    title: str

class QueryRequest(BaseModel):
    query: str
    
class QueryResponse(BaseModel):
    answer: str
    triplets: List[TripletData]
    relations: List[RelationDefinition]
    news_items: List[NewsItem]
    graph_data: GraphData
    
class ExtractedInformationNews(BaseModel):
    extracted_information: str = Field(description="Extracted information")
    links: list = Field(description="citation links")

class ExtractedInformation(BaseModel):
    extracted_information: str = Field(description="Extracted information")

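# Short-lived SQLite connections: each helper opens its own connection
# for one block of queries and always closes it, even if a query raises.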
@contextlib.contextmanager
def get_triplets_db():
    conn = None
    try:
        conn = sqlite3.connect(DATABASE_CONFIG["triplets_db"])
        yield conn
    finally:
        if conn:
            conn.close()

@contextlib.contextmanager
def get_news_db():
    conn = None
    try:
        conn = sqlite3.connect(DATABASE_CONFIG["news_db"])
        yield conn
    finally:
        if conn:
            conn.close()

@contextlib.contextmanager
def get_definitions_db():
    conn = None
    try:
        conn = sqlite3.connect(DATABASE_CONFIG["definitions_db"])
        yield conn
    finally:
        if conn:
            conn.close()

def retrieve_triplets(query: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """
    Args:
        query (str): User query
        
    Returns:
        Tuple containing:
        - List of triplets: [(head, relation, tail), ...]
        - List of relations with definitions: [(relation, definition), ...]
    """
    API_KEY = os.environ.get("TOGETHER_API_KEY")
    client = Together(api_key = API_KEY)
    
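    # The FAISS index was built offline from precomputed embeddings;
    # FakeEmbeddings(size=768) is only a placeholder so load_local can
    # construct the store. Query embeddings come from the Together API.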
    dummy_embeddings = FakeEmbeddings(size=768)
    triplets_store = FAISS.load_local(
        "triplets_index_compressed", dummy_embeddings, allow_dangerous_deserialization=True
    )
    triplets_store.index.nprobe = 100
    triplets_store._normalize_L2 = True
    triplets_store.distance_strategy = DistanceStrategy.COSINE

    response = client.embeddings.create(
      model = "Alibaba-NLP/gte-modernbert-base",
      input = query
    )

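    # L2-normalize the query embedding so the inner-product search over
    # the normalized index is equivalent to cosine similarity.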
    emb = np.array(response.data[0].embedding)
    emb = emb / np.linalg.norm(emb)
    
    related_head_entity = []
    result_triplets = triplets_store.similarity_search_with_score_by_vector(emb, k=100)
    for res, score in result_triplets:
        if score > 0.7:
            related_head_entity.append(res)
            
    try:
        all_triplets = []
        with get_triplets_db() as conn:
            table = DATABASE_CONFIG["triplets_table"]
            head_col = DATABASE_CONFIG["head_column"]
            rel_col = DATABASE_CONFIG["relation_column"]
            tail_col = DATABASE_CONFIG["tail_column"]

            cursor = conn.cursor()
            for head_entity in related_head_entity:
                he = head_entity.page_content
                cursor.execute(
                    f"SELECT {head_col}, {rel_col}, {tail_col} FROM {table} WHERE {head_col} = ?",
                    (he,),
                )
                rows = cursor.fetchall()
                all_triplets += [(str(row[0]), str(row[1]), str(row[2])) for row in rows]
            
        all_relations = []
        relations = {relation for _, relation, _ in all_triplets}
        with get_definitions_db() as conn:
            def_table = DATABASE_CONFIG["definitions_table"]
            rel_col = DATABASE_CONFIG["relation_column"]
            def_col = DATABASE_CONFIG["definition_column"]

            cursor = conn.cursor()
            for rel in relations:
                cursor.execute(
                    f"SELECT {rel_col}, {def_col} FROM {def_table} WHERE {rel_col} = ?",
                    (rel,),
                )
                rows = cursor.fetchall()
                all_relations += [(str(row[0]), str(row[1])) for row in rows]

        return all_triplets, all_relations
        
    except Exception as e:
        print(f"Error in retrieve_triplets: {e}")
        return [], []

def retrieve_news(query: str) -> Tuple[List[str], List[str]]:
    """
    Retrieve news passages semantically similar to the query.

    Args:
        query (str): User query

    Returns:
        Tuple containing:
        - List of related content, one joined string per source article
        - List of links for the related content
    """
    API_KEY = os.environ.get("TOGETHER_API_KEY")
    client = Together(api_key = API_KEY)
    
    dummy_embeddings = FakeEmbeddings(size=768)
    news_store = FAISS.load_local(
        "news_index_compressed", dummy_embeddings, allow_dangerous_deserialization=True
    )
    news_store.index.nprobe = 100
    news_store._normalize_L2 = True
    news_store.distance_strategy = DistanceStrategy.COSINE
    

    response = client.embeddings.create(
      model = "Alibaba-NLP/gte-modernbert-base",
      input = query
    )

    emb = np.array(response.data[0].embedding)
    emb = emb / np.linalg.norm(emb)

    related_news_content = []
    result_news = news_store.similarity_search_with_score_by_vector(emb, k=500)
    for res, score in result_news:
        if score > 0.7:
            related_news_content.append(res)
    
    # Group retrieved passages by their source article link so each link
    # appears once and its passages are joined into one block of content.
    news_dict = defaultdict(list)
    for res in related_news_content:
        news_dict[res.metadata["link"]].append(res.page_content)

    content_only = [". ".join(sentences) for sentences in news_dict.values()]
    links = list(news_dict.keys())

    return content_only, links


def extract_information_from_triplets(query: str,
                                      triplets: List[Tuple[str, str, str]], 
                                      relations: List[Tuple[str, str]]) -> str:
    """
    REPLACE THIS FUNCTION WITH YOUR ACTUAL IMPLEMENTATION
    
    Args:
        triplets: List of triplets from retrieve_triplets
        relations: List of relation definitions from retrieve_triplets
        
    Returns:
        str: Extracted information from triplets
    """
    system_prompt = '''Given a list of relational triplets and a list of relations with their definitions, extract the information from the triplets needed to answer the query question.
    If no related or useful information can be extracted from the triplets to answer the query question, reply "No related information found."
    Give the output in narrative paragraph form; you may explain the reasoning behind your answer in detail.
    '''
    
    user_prompt = f'''
    query question: {query}
    list of triplets: {triplets}
    list of relations and their definition: {relations}
    extracted information:
    '''

    API_KEY = os.environ.get("TOGETHER_API_KEY")
    client = Together(api_key = API_KEY)

    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        temperature = 0,
        messages=[{
            "role": "system",
            "content": [
                {"type": "text", "text":system_prompt}
            ]
        },
                {
           "role": "user",
           "content": [
               {"type": "text", "text":user_prompt},
            ]
        }]
    )
    
    return response.choices[0].message.content

def extract_information_from_news(query: str,
                                  news_list: List[str]) -> str:
    """
    Args:
        query: User query
        news_list: List of related content from retrieve_news

    Returns:
        str: Extracted information from the news content
    """
    system_prompt = '''Given a list of information related to the query, extract all important information from the list to answer the query question.
    Every item in the list represents one piece of information; if a piece of information is ambiguous (e.g. contains a pronoun whose referent is unknown), do not use it to answer the query.
    You do not have to use every item, only those that are clear and well grounded, but try to use as much of the information as possible.
    If no related or useful information can be extracted from the news to answer the query question, write "No related information found." as the extracted_information output.
    Give the extracted_information output in detailed paragraph form.
    The output must be in this form: {"extracted_information": <output paragraphs>}
    '''
    
    user_prompt = f'''
    query: {query}
    news list: {news_list}
    output:
    '''

    API_KEY = os.environ.get("TOGETHER_API_KEY")
    client = Together(api_key=API_KEY)

    response = client.chat.completions.create(
       model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
       response_format={
            "type": "json_schema",
            "schema": ExtractedInformation.model_json_schema(),
       },
       temperature = 0,
       messages=[{
           "role": "system",
           "content": [
               {"type": "text", "text":system_prompt}
           ]
       },
                {
           "role": "user",
           "content": [
               {"type": "text", "text":user_prompt},
           ]
       }]
    )
    response = json.loads(response.choices[0].message.content)
    info = response['extracted_information']
    
    return info

def extract_information(query: str, triplet_info: str, news_info: str, language: str) -> str:
    """
    Args:
        query: User query
        triplet_info: Information extracted from triplets
        news_info: Information extracted from news
        language: Language the final answer should be written in

    Returns:
        str: Final answer for the user
    """
    API_KEY = os.environ.get("TOGETHER_API_KEY")
    client = Together(api_key=API_KEY)
    system_prompt = f'''Given information from two sources, combine the information into a comprehensive and informative paragraph that answers the query.
    Make sure the output paragraph includes all crucial information, given in detail.
    If no related or useful information can be extracted from the sources to answer the query question, reply "No related information found."
    Remember this paragraph will be shown to the user, so make sure it is based on facts and data, and use appropriate language.
    The output must be in this form and in the {language} language: {{"extracted_information": <output paragraphs>}}
    '''
    
    user_prompt = f'''
    query: {query}
    first source: {triplet_info}
    second source: {news_info}
    extracted information:
    '''

    response = client.chat.completions.create(
       model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
       response_format={
            "type": "json_schema",
            "schema": ExtractedInformation.model_json_schema(),
       },
       temperature = 0,
       messages=[{
           "role": "system",
           "content": [
               {"type": "text", "text":system_prompt}
           ]
       },
                {
           "role": "user",
           "content": [
               {"type": "text", "text":user_prompt},
           ]
       }]
    )
    
    response = json.loads(response.choices[0].message.content)
    answer = response["extracted_information"]
    return answer

def news_preview(links: List[str]) -> List[Tuple[str, str, str]]:
    """Fetch (link, title, content) rows for the given news links."""
    try:
        preview_contents = []
        with get_news_db() as conn:
            cursor = conn.cursor()
            for link in links:
                cursor.execute(
                    "SELECT link, title, content FROM CNNHEALTHNEWS2 WHERE link = ?",
                    (link,),
                )
                rows = cursor.fetchall()
                preview_contents += [(str(row[0]), str(row[1]), str(row[2])) for row in rows]

        return preview_contents

    except Exception as e:
        print(f"Error in news_preview: {e}")
        return []

class Language(BaseModel):
    query: str = Field(description="Translated query")
    language: str = Field(description="Query's language")
    
def query_language(query: str) -> Dict[str, str]:
    """Detect the query's language and translate the query to English if needed."""
    system_prompt = '''Your task is to determine what language the question is written in and translate it to English if it is not in English.
    The output must be in this form: {"query": <translated query>, "language": <query's language>}
    '''

    user_prompt = f'''
    query: {query}
    output:
    '''
    API_KEY = os.environ.get("TOGETHER_API_KEY")
    client = Together(api_key=API_KEY)
    
    response = client.chat.completions.create(
       model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
       response_format={
            "type": "json_schema",
            "schema": Language.model_json_schema(),
       },
       temperature = 0,
       messages=[{
           "role": "system",
           "content": [
               {"type": "text", "text":system_prompt}
           ]
       },
                {
           "role": "user",
           "content": [
               {"type": "text", "text":user_prompt},
           ]
       }])

    return json.loads(response.choices[0].message.content)

# API ENDPOINTS

@app.post("/api/query", response_model=QueryResponse)
def process_query(request: QueryRequest):
    """Process user query and return comprehensive response"""
    try:
        # Step 1: Detect the query language and translate to English
        query = query_language(request.query)

        # Step 2: Retrieve triplets
        triplets_data, relations_data = retrieve_triplets(query['query'])

        # Step 3: Retrieve news
        news_list, news_links = retrieve_news(query['query'])

        # Step 4: Extract information from triplets
        triplet_info = extract_information_from_triplets(query['query'], triplets_data, relations_data)

        # Step 5: Extract information from news
        news_info = extract_information_from_news(query['query'], news_list)

        # Step 6: Generate final answer in the query's original language
        final_answer = extract_information(query['query'], triplet_info, news_info, query['language'])
        
        # Convert triplets to response format
        triplets = [TripletData(head=t[0], relation=t[1], tail=t[2]) for t in triplets_data]
        relations = [RelationDefinition(relation=r[0], definition=r[1]) for r in relations_data]
        
        # Convert news to response format with previews
        news_prev = news_preview(news_links)
        news_items = []
        for url, title, content in news_prev:
            preview = content[:300] + "..." if len(content) > 300 else content
            news_items.append(NewsItem(
                url=url,
                content=content,
                preview=preview,
                title=title
            ))
        
        # Create mini graph data for visualization
        nodes_set = set()
        edges = []
        
        for triplet in triplets_data:
            head, relation, tail = triplet
            nodes_set.add(head)
            nodes_set.add(tail)
            
            # Find definition for this relation
            definition = "No definition available"
            for rel, def_text in relations_data:
                if rel == relation:
                    definition = def_text
                    break
            
            edges.append(GraphEdge(
                source=head,
                target=tail,
                relation=relation,
                definition=definition
            ))
        
        nodes = [GraphNode(id=node, label=node) for node in nodes_set]
        graph_data = GraphData(nodes=nodes, edges=edges)
        
        return QueryResponse(
            answer=final_answer,
            triplets=triplets,
            relations=relations,
            news_items=news_items,
            graph_data=graph_data
        )
        
    except Exception as e:
        print(f"Error in process_query: {e}")
        raise HTTPException(status_code=500, detail=f"Query processing failed: {str(e)}")

@app.get("/api/graph", response_model=GraphData)
def get_graph_data(
    search: Optional[str] = None,
    triplets_db: sqlite3.Connection = Depends(get_triplets_connection),
    definitions_db: sqlite3.Connection = Depends(get_definitions_connection)
):
    """Get complete graph data with nodes and edges."""
    
    try:
        # Build dynamic query based on configuration
        table = DATABASE_CONFIG["triplets_table"]
        head_col = DATABASE_CONFIG["head_column"]
        rel_col = DATABASE_CONFIG["relation_column"] 
        tail_col = DATABASE_CONFIG["tail_column"]
        
        base_query = f"SELECT {head_col}, {rel_col}, {tail_col} FROM {table}"
        params = []
        
        if search:
            base_query += f" WHERE {head_col} LIKE ? OR {tail_col} LIKE ? OR {rel_col} LIKE ?"
            search_term = f"%{search}%"
            params = [search_term, search_term, search_term]
        
        base_query += " LIMIT 1000"
        
        # Get triplets
        with get_triplets_db() as conn:
            cursor = conn.execute(base_query, params)
            triplets = cursor.fetchall()

        with get_definitions_db() as conn:
            # Get definitions
            def_table = DATABASE_CONFIG["definitions_table"]
            def_col = DATABASE_CONFIG["definition_column"]
            rel_col_def = DATABASE_CONFIG["relation_column"]
            
            def_cursor = conn.execute(f"SELECT {rel_col_def}, {def_col} FROM {def_table}")
            definitions = {row[0]: row[1] for row in def_cursor.fetchall()}
        
        # Build nodes and edges
        nodes_set = set()
        edges = []
        
        for triple in triplets:
            head = triple[0]
            relation = triple[1]
            tail = triple[2]
            
            # Add entities to nodes set
            nodes_set.add(head)
            nodes_set.add(tail)
            
            # Create edge with definition
            edge = GraphEdge(
                source=head,
                target=tail,
                relation=relation,
                definition=definitions.get(relation, "No definition available")
            )
            edges.append(edge)
        
        # Convert nodes set to list of GraphNode objects
        nodes = [GraphNode(id=node, label=node) for node in nodes_set]
        
        return GraphData(nodes=nodes, edges=edges)
        
    except Exception as e:
        print(f"Error in get_graph_data: {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {str(e)}")

if __name__ == "__main__":
    print("Starting Knowledge Graph API...")
    print(f"Triplets DB: {DATABASE_CONFIG['triplets_db']}")
    print(f"Definitions DB: {DATABASE_CONFIG['definitions_db']}")
    
    import uvicorn
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
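
# Example request (hypothetical query, assuming the server is running
# locally on the default port 8000):
#   curl -X POST http://localhost:8000/api/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What are common symptoms of the flu?"}'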