Commit bf292d9
Parent(s): e846827
init commit

Files changed:
- app.py              +92 -0
- cloud_event.py      +23 -0
- config.py           +22 -0
- listener.py         +47 -0
- rabbit_base.py      +40 -0
- rabbit_repo.py      +34 -0
- runners/base.py     +17 -0
- runners/service.py +127 -0
- utils.py            +11 -0
app.py
ADDED
@@ -0,0 +1,92 @@
import asyncio
import gradio as gr
from fastapi import FastAPI

from app.config import settings
from app.listener import RabbitListenerBase
from app.rabbit_repo import RabbitRepo
from app.service import LLMService
from app.runners.base import ILLMRunner

# --- Runner factory (stub) ---
class EchoRunner(ILLMRunner):
    Type = "EchoRunner"
    async def StartProcess(self, llmServiceObj: dict): pass
    async def RemoveProcess(self, sessionId: str): pass
    async def StopRequest(self, sessionId: str): pass
    async def SendInputAndGetResponse(self, llmServiceObj: dict):
        # Emits a message back (you can choose queue names per your topology)
        pass

async def runner_factory(llmServiceObj: dict) -> ILLMRunner:
    # Use llmServiceObj["LLMRunnerType"] to instantiate different runners
    return EchoRunner()

# --- Publisher and Service ---
publisher = RabbitRepo(external_source="https://space.external")  # put your ExternalUrl if you have one
service = LLMService(publisher, runner_factory)

# --- Handlers mapping .NET FuncName -> service method ---
async def h_start(data): await service.StartProcess(data or {})
async def h_user(data): await service.UserInput(data or {})
async def h_remove(data): await service.RemoveSession(data or {})
async def h_stop(data): await service.StopRequest(data or {})
async def h_qir(data): await service.QueryIndexResult(data or {})
async def h_getreg(data): await service.GetFunctionRegistry(False)
async def h_getreg_f(data): await service.GetFunctionRegistry(True)

handlers = {
    "llmStartSession": h_start,
    "llmUserInput": h_user,
    "llmRemoveSession": h_remove,
    "llmStopRequest": h_stop,
    "queryIndexResult": h_qir,
    "getFunctionRegistry": h_getreg,
    "getFunctionRegistryFiltered": h_getreg_f,
}

listener = RabbitListenerBase(service_id=settings.SERVICE_ID, handlers=handlers)

# Declarations mirror your C# InitRabbitMQObjs()
DECLS = [
    {"ExchangeName": "llmStartSession" + settings.SERVICE_ID, "FuncName": "llmStartSession", "MessageTimeout": 600000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
    {"ExchangeName": "llmUserInput" + settings.SERVICE_ID, "FuncName": "llmUserInput", "MessageTimeout": 600000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
    {"ExchangeName": "llmRemoveSession" + settings.SERVICE_ID, "FuncName": "llmRemoveSession", "MessageTimeout": 60000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
    {"ExchangeName": "llmStopRequest" + settings.SERVICE_ID, "FuncName": "llmStopRequest", "MessageTimeout": 60000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
    {"ExchangeName": "queryIndexResult" + settings.SERVICE_ID, "FuncName": "queryIndexResult", "MessageTimeout": 60000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
    {"ExchangeName": "getFunctionRegistry" + settings.SERVICE_ID, "FuncName": "getFunctionRegistry", "MessageTimeout": 60000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
    {"ExchangeName": "getFunctionRegistryFiltered" + settings.SERVICE_ID, "FuncName": "getFunctionRegistryFiltered", "MessageTimeout": 60000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
]

# --- Gradio UI (for smoke test) ---
async def ping():
    return "ok"

with gr.Blocks() as demo:
    gr.Markdown("### LLM Runner (Python) listening on RabbitMQ")
    btn = gr.Button("Ping")
    out = gr.Textbox()
    btn.click(ping, inputs=None, outputs=out)

# --- FastAPI mount + lifecycle ---
app = FastAPI()
app = gr.mount_gradio_app(app, demo, path="/")

@app.get("/health")
async def health():
    return {"status": "ok"}

@app.on_event("startup")
async def on_start():
    await publisher.connect()
    await service.init()
    await listener.start(DECLS)

@app.on_event("shutdown")
async def on_stop():
    # aio-pika RobustConnection closes on GC; optionally add explicit closes if you keep references
    pass

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
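Not part of the commit, but handy as a smoke test: the sketch below publishes a CloudEvent-style envelope to the start-session exchange that the listener above binds for the default SERVICE_ID ("monitor"). The broker URL, exchange name, and payload fields are assumptions inferred from DECLS and LLMService.StartProcess; adjust them to your topology.

# example_start_session.py (illustrative sketch, not part of the commit)
import asyncio, json, uuid
import aio_pika

async def main():
    conn = await aio_pika.connect_robust("amqp://guest:guest@localhost/")  # assumed local broker
    chan = await conn.channel()
    # exchange name mirrors DECLS above: "llmStartSession" + SERVICE_ID
    ex = await chan.declare_exchange("llmStartSessionmonitor", aio_pika.ExchangeType.TOPIC, durable=True)
    envelope = {"id": str(uuid.uuid4()), "type": "LLMServiceObj", "source": "test", "time": "",
                "data": {"RequestSessionId": "abc", "LLMRunnerType": "EchoRunner"}}
    await ex.publish(aio_pika.Message(body=json.dumps(envelope).encode("utf-8")), routing_key="")
    await conn.close()

asyncio.run(main())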
cloud_event.py
ADDED
@@ -0,0 +1,23 @@
import json
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from typing import Any

@dataclass
class CloudEvent:
    id: str
    type: str
    source: str
    time: str
    data: Any

    @staticmethod
    def wrap(obj: Any, *, event_type: str, source: str, id: str) -> bytes:
        evt = CloudEvent(
            id=id,
            type=event_type or (obj.__class__.__name__ if obj is not None else "NullOrEmpty"),
            source=source,
            time=datetime.now(timezone.utc).isoformat(),
            data=obj,
        )
        return json.dumps(asdict(evt), ensure_ascii=False).encode("utf-8")
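A minimal usage sketch of the envelope helper (the import path is an assumption; adjust it to wherever the module lives in your package):

# illustrative only, not part of the commit
import json
from cloud_event import CloudEvent  # assumed import path

body = CloudEvent.wrap({"SessionId": "abc_EchoRunner", "UserInput": "hello"},
                       event_type="LLMServiceObj", source="https://example.local", id="demo-1")
print(json.loads(body.decode("utf-8"))["type"])  # -> LLMServiceObj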
config.py
ADDED
@@ -0,0 +1,22 @@
from pydantic import BaseSettings, AnyUrl  # pydantic v1 API; in pydantic v2, BaseSettings moved to the pydantic-settings package
from typing import Optional, Dict

class Settings(BaseSettings):
    # AMQP
    AMQP_URL: AnyUrl  # e.g. amqps://user:pass@host:5671/%2F?heartbeat=30
    RABBIT_INSTANCE_NAME: str = "prod"
    RABBIT_EXCHANGE_TYPE: str = "topic"  # match your .NET Type
    RABBIT_ROUTING_KEY: str = ""  # match your .NET RoutingKeys ("" is ok)
    RABBIT_PREFETCH: int = 1

    # Service identity
    SERVICE_ID: str = "monitor"  # "monitor" | "nmap" | ...
    USE_TLS: bool = True

    # Optional exchange type overrides by prefix, like the .NET ExchangeTypes map
    EXCHANGE_TYPES: Dict[str, str] = {}

    class Config:
        case_sensitive = True

settings = Settings()  # env-driven
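Settings are read from the environment at import time; a minimal sketch of supplying them in-process for local testing (placeholder values, assumed import path):

# illustrative only, not part of the commit
import os
os.environ.setdefault("AMQP_URL", "amqp://guest:guest@localhost:5672/%2F")  # placeholder broker URL
os.environ.setdefault("SERVICE_ID", "monitor")

from config import settings  # assumed import path
print(settings.SERVICE_ID, settings.RABBIT_EXCHANGE_TYPE)  # -> monitor topic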
listener.py
ADDED
@@ -0,0 +1,47 @@
import json
from typing import Awaitable, Callable, Dict, List, Optional
import aio_pika
from .rabbit_base import RabbitBase
from .config import settings

# Maps FuncName -> handler coroutine
Handler = Callable[[dict], Awaitable[None]]

class RabbitListenerBase(RabbitBase):
    def __init__(self, service_id: str, handlers: Dict[str, Handler]):
        super().__init__()
        self._service_id = service_id
        self._handlers = handlers
        self._consumers: List[aio_pika.abc.AbstractRobustQueue] = []

    def _qname(self, exchange: str, routing_keys: List[str]) -> str:
        rk_part = "-".join(sorted([rk for rk in (routing_keys or [""]) if rk != ""])) or ""
        suffix = f"-{rk_part}" if rk_part else ""
        return f"{settings.RABBIT_INSTANCE_NAME}-{exchange}{suffix}"

    async def start(self, declarations: List[dict]):
        """
        declarations: list of {ExchangeName, FuncName, MessageTimeout, Type?, RoutingKeys?}
        """
        for d in declarations:
            exch = d["ExchangeName"]
            rks = d.get("RoutingKeys") or [settings.RABBIT_ROUTING_KEY]
            ttl = d.get("MessageTimeout") or None
            q = await self.declare_queue_bind(exchange=exch, queue_name=self._qname(exch, rks), routing_keys=rks, ttl_ms=ttl)
            await q.consume(self._make_consumer(d["FuncName"]))
            self._consumers.append(q)

    def _make_consumer(self, func_name: str):
        handler = self._handlers.get(func_name)

        async def _on_msg(msg: aio_pika.IncomingMessage):
            async with msg.process():
                try:
                    # Expect CloudEvent JSON
                    envelope = json.loads(msg.body.decode("utf-8"))
                    data = envelope.get("data")
                    if handler:
                        await handler(data)
                except Exception:
                    # swallow to avoid nack loops; your logger can capture details
                    pass

        return _on_msg
rabbit_base.py
ADDED
@@ -0,0 +1,40 @@
import asyncio, json, uuid
import aio_pika
from typing import Callable, Dict, List, Optional
from .config import settings

ExchangeResolver = Callable[[str], str]  # exchangeName -> exchangeType

class RabbitBase:
    def __init__(self, exchange_type_resolver: Optional[ExchangeResolver] = None):
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._chan: Optional[aio_pika.RobustChannel] = None
        self._exchanges: Dict[str, aio_pika.Exchange] = {}
        self._exchange_type_resolver = exchange_type_resolver or (lambda _: settings.RABBIT_EXCHANGE_TYPE)

    async def connect(self):
        if self._conn and not self._conn.is_closed:
            return
        self._conn = await aio_pika.connect_robust(str(settings.AMQP_URL))
        self._chan = await self._conn.channel()
        await self._chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)

    async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
        await self.connect()
        if name in self._exchanges:
            return self._exchanges[name]
        ex_type = self._exchange_type_resolver(name)
        # ExchangeType is a str enum keyed by value ("topic", "fanout", ...), so build it by value
        ex = await self._chan.declare_exchange(name, aio_pika.ExchangeType(ex_type), durable=True)
        self._exchanges[name] = ex
        return ex

    async def declare_queue_bind(self, exchange: str, queue_name: str, routing_keys: List[str], ttl_ms: Optional[int]):
        await self.connect()
        ex = await self.ensure_exchange(exchange)
        args = {}
        if ttl_ms:
            args["x-message-ttl"] = ttl_ms
        q = await self._chan.declare_queue(queue_name, durable=True, exclusive=False, auto_delete=True, arguments=args)
        for rk in routing_keys or [""]:
            await q.bind(ex, rk)
        return q
rabbit_repo.py
ADDED
@@ -0,0 +1,34 @@
import uuid, aio_pika
from typing import Any
from .rabbit_base import RabbitBase
from .cloud_event import CloudEvent
from .config import settings
from .utils import to_json, json_compress_str

class RabbitRepo(RabbitBase):
    def __init__(self, external_source: str):
        super().__init__(exchange_type_resolver=self._resolve_type)
        self._external_source = external_source  # like SystemUrl.ExternalUrl

    def _resolve_type(self, exch: str) -> str:
        # longest prefix wins (like your .NET mapping)
        matches = [k for k in settings.EXCHANGE_TYPES.keys() if exch.lower().startswith(k.lower())]
        if matches:
            return settings.EXCHANGE_TYPES[max(matches, key=len)]
        return settings.RABBIT_EXCHANGE_TYPE

    async def publish(self, exchange: str, obj: Any, routing_key: str = ""):
        ex = await self.ensure_exchange(exchange)
        payload = CloudEvent.wrap(obj, event_type=(obj.__class__.__name__ if obj is not None else "NullOrEmpty"),
                                  source=self._external_source, id=str(uuid.uuid4()))
        await ex.publish(aio_pika.Message(body=payload), routing_key=routing_key)

    async def publish_jsonz(self, exchange: str, obj: Any, routing_key: str = "", with_id: str | None = None):
        ex = await self.ensure_exchange(exchange)
        json_str = to_json(obj)
        datajsonZ = json_compress_str(json_str)
        to_send = (datajsonZ, with_id) if with_id else datajsonZ
        payload = CloudEvent.wrap(to_send, event_type=(obj.__class__.__name__ if obj is not None else "NullOrEmpty"),
                                  source=self._external_source, id=str(uuid.uuid4()))
        await ex.publish(aio_pika.Message(body=payload), routing_key=routing_key)
        return datajsonZ
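A publishing sketch, assuming a reachable broker configured via AMQP_URL and an assumed import path:

# illustrative only, not part of the commit
import asyncio
from rabbit_repo import RabbitRepo  # assumed import path

async def main():
    repo = RabbitRepo(external_source="https://example.local")
    await repo.connect()
    # plain CloudEvent publish
    await repo.publish("llmServiceMessage", {"ResultSuccess": True, "ResultMessage": "hello"})
    # compressed variant returns the base64/zlib payload it sent
    z = await repo.publish_jsonz("llmServiceMessage", {"ResultMessage": "a larger object"})
    print(len(z))

asyncio.run(main())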
runners/base.py
ADDED
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from typing import Any

class ILLMRunner(ABC):
    Type: str = "BaseLLM"
    IsEnabled: bool = True
    IsStateStarting: bool = False
    IsStateFailed: bool = False

    @abstractmethod
    async def StartProcess(self, llmServiceObj: dict) -> None: ...
    @abstractmethod
    async def RemoveProcess(self, sessionId: str) -> None: ...
    @abstractmethod
    async def StopRequest(self, sessionId: str) -> None: ...
    @abstractmethod
    async def SendInputAndGetResponse(self, llmServiceObj: dict) -> None: ...
runners/service.py
ADDED
@@ -0,0 +1,127 @@
import asyncio
from typing import Dict, Optional
from collections import defaultdict

from .rabbit_repo import RabbitRepo
from .config import settings
from .runners.base import ILLMRunner

class LLMService:
    def __init__(self, publisher: RabbitRepo, runner_factory):
        self._pub = publisher
        self._runner_factory = runner_factory
        self._sessions: Dict[str, dict] = {}  # sessionId -> {"Runner": ILLMRunner, "FullSessionId": str}
        self._ready = asyncio.Event()
        self._ready.set()  # if you have an async load, clear here and set after it finishes

    async def init(self):
        # If you have history to load, do it here, then self._ready.set()
        pass

    async def _set_result(self, obj: dict, message: str, success: bool, queue: str, check_system: bool = False):
        obj["ResultMessage"] = message
        obj["ResultSuccess"] = success
        obj["LlmMessage"] = (f"<Success>{message}</Success>" if success else f"<Error>{message}</Error>")
        # mirror your .NET rule (don't publish for the system llm if check_system is True)
        if not (check_system and obj.get("IsSystemLlm")):
            await self._pub.publish(queue, obj)

    async def StartProcess(self, llmServiceObj: dict):
        session_id = f"{llmServiceObj['RequestSessionId']}_{llmServiceObj['LLMRunnerType']}"
        llmServiceObj["SessionId"] = session_id

        # wait until ready (max ~120s, like .NET)
        try:
            await asyncio.wait_for(self._ready.wait(), timeout=120)
        except asyncio.TimeoutError:
            await self._set_result(llmServiceObj, "Timed out waiting for initialization.", False, "llmServiceMessage", True)
            return

        sess = self._sessions.get(session_id)
        is_runner_null = not sess or not sess.get("Runner")

        create_new = is_runner_null or sess["Runner"].IsStateFailed
        if create_new:
            if sess and sess.get("Runner"):
                try:
                    await sess["Runner"].RemoveProcess(session_id)
                except Exception: pass

            runner: ILLMRunner = await self._runner_factory(llmServiceObj)
            if not runner.IsEnabled:
                await self._set_result(llmServiceObj, f"{llmServiceObj['LLMRunnerType']} {settings.SERVICE_ID} not started as it is disabled.", True, "llmServiceMessage")
                return

            await self._set_result(llmServiceObj, f"Starting {runner.Type} {settings.SERVICE_ID} Expert", True, "llmServiceMessage", True)
            await runner.StartProcess(llmServiceObj)

            self._sessions[session_id] = {"Runner": runner, "FullSessionId": session_id}
            if settings.SERVICE_ID == "monitor":
                await self._set_result(llmServiceObj, f"Hi, I'm {runner.Type}, your Network Monitor Assistant. How can I help you?", True, "llmServiceMessage", True)

        await self._pub.publish("llmServiceStarted", llmServiceObj)

    async def RemoveSession(self, llmServiceObj: dict):
        # Behaves like your RemoveAllSessionIdProcesses (prefix match)
        base = llmServiceObj.get("SessionId", "").split("_")[0]
        targets = [k for k in self._sessions.keys() if k.startswith(base + "_")]
        msgs = []
        ok = True
        for sid in targets:
            s = self._sessions.get(sid)
            if s and s.get("Runner"):
                try:
                    await s["Runner"].RemoveProcess(sid)
                    s["Runner"] = None
                    msgs.append(sid)
                except Exception as e:
                    ok = False
                    msgs.append(f"Error {sid}: {e}")
        if ok:
            await self._set_result(llmServiceObj, f"Success: Removed sessions for {' '.join(msgs)}", True, "llmSessionMessage", True)
        else:
            await self._set_result(llmServiceObj, " ".join(msgs), False, "llmServiceMessage")

    async def StopRequest(self, llmServiceObj: dict):
        sid = llmServiceObj.get("SessionId", "")
        s = self._sessions.get(sid)
        if not s or not s.get("Runner"):
            await self._set_result(llmServiceObj, f"Error: Runner missing for session {sid}.", False, "llmServiceMessage")
            return
        await s["Runner"].StopRequest(sid)
        await self._set_result(llmServiceObj, f"Success: {s['Runner'].Type} {settings.SERVICE_ID} Assistant output has been halted", True, "llmServiceMessage", True)

    async def UserInput(self, llmServiceObj: dict):
        sid = llmServiceObj.get("SessionId", "")
        s = self._sessions.get(sid)
        if not s or not s.get("Runner"):
            await self._set_result(llmServiceObj, f"Error: SessionId {sid} has no running process.", False, "llmServiceMessage")
            return
        r: ILLMRunner = s["Runner"]
        if r.IsStateStarting:
            await self._set_result(llmServiceObj, "Please wait, the assistant is starting...", False, "llmServiceMessage")
            return
        if r.IsStateFailed:
            await self._set_result(llmServiceObj, "The Assistant is stopped. Try reloading.", False, "llmServiceMessage")
            return
        await r.SendInputAndGetResponse(llmServiceObj)
        # the emitter side can push partials directly to queues if desired

    async def QueryIndexResult(self, queryIndexRequest: dict):
        # Adapted from your .NET behavior: concatenate outputs, publish completion via an internal coordinator if needed
        try:
            rag_data = "\n".join([qr.get("Output", "") for qr in (queryIndexRequest.get("QueryResults") or [])])
            # You signal _queryCoordinator.CompleteQuery in .NET; here you may forward/publish the result instead.
            # Example: include the RAG data in a service message to the session
            await self._pub.publish("llmServiceMessage", {
                "ResultSuccess": queryIndexRequest.get("Success", False),
                "ResultMessage": queryIndexRequest.get("Message", ""),
                "Data": rag_data,
            })
        except Exception as e:
            await self._pub.publish("llmServiceMessage", {"ResultSuccess": False, "ResultMessage": str(e)})

    async def GetFunctionRegistry(self, filtered: bool = False):
        # Plug in your registry here
        data = {"FunctionCatalogJson": "{}", "Filtered": filtered}
        await self._pub.publish("llmServiceMessage", {"ResultSuccess": True, "ResultMessage": f"Success: Got GetFunctionCatalogJson : {data}"})
utils.py
ADDED
@@ -0,0 +1,11 @@
import json, zlib, base64
from typing import Any

def to_json(obj: Any) -> str:
    return json.dumps(obj, ensure_ascii=False, separators=(",", ":"))

def json_compress_str(s: str) -> str:
    return base64.b64encode(zlib.compress(s.encode("utf-8"), level=6)).decode("ascii")

def json_decompress_str(s: str) -> str:
    return zlib.decompress(base64.b64decode(s)).decode("utf-8")
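A round-trip check of the compression helpers (assumed import path):

# illustrative only, not part of the commit
import json
from utils import to_json, json_compress_str, json_decompress_str  # assumed import path

obj = {"QueryResults": [{"Output": "ping ok"}]}
z = json_compress_str(to_json(obj))  # base64-encoded, zlib-compressed JSON
assert json.loads(json_decompress_str(z)) == obj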