|
|
|
|
|
from __future__ import annotations |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, List |
|
|
import random |
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class TrackedCall:
    """State of a single function call tracked for one message.

    Instances are mutable on purpose: ``FunctionCallTracker.mark_processed``
    flips ``IsProcessed`` and may replace ``Payload`` in place.
    """

    # NOTE(review): the PascalCase field names presumably mirror an external
    # schema; they are part of the public interface, so keep them unchanged.

    # Identifier of the call (the tracker generates ``call_`` + 8 digits).
    FunctionCallId: str
    # Name of the function that was called.
    FunctionName: str
    # Becomes True once the call has been marked as processed.
    IsProcessed: bool = False
    # Opaque string data attached to the call (presumably its arguments or
    # result -- confirm with callers).
    Payload: str = ""
|
|
|
|
|
class FunctionCallTracker:
    """Track function calls per message and whether each one has been processed.

    Calls are grouped by ``message_id``. Within a message, each call is keyed
    by a generated call ID.
    """

    def __init__(self) -> None:
        # message_id -> (call_id -> TrackedCall)
        self._by_msg: Dict[str, Dict[str, TrackedCall]] = {}

    @staticmethod
    def gen_id() -> str:
        """Return a random call ID of the form ``call_`` followed by 8 digits.

        Uses the shared ``random`` module generator. IDs are therefore not
        cryptographically secure and not unique on their own; ``add`` guards
        against collisions within a message.
        """
        return f"call_{random.randint(10_000_000, 99_999_999)}"

    def _unique_call_id(self, existing: Dict[str, TrackedCall]) -> str:
        """Return an ID from ``gen_id`` that is not already a key of *existing*."""
        call_id = self.gen_id()
        # The ~90M-value ID space makes repeats rare but possible; retry on a hit.
        while call_id in existing:
            call_id = self.gen_id()
        return call_id

    def add(self, message_id: str, fn_name: str, payload: str) -> str:
        """Register a new, unprocessed call under *message_id*.

        Args:
            message_id: Message the call belongs to; created on first use.
            fn_name: Name of the called function.
            payload: Opaque string stored on the call.

        Returns:
            The generated call ID, unique among this message's tracked calls.
        """
        calls = self._by_msg.setdefault(message_id, {})
        # Without the uniqueness check, a colliding random ID would silently
        # overwrite an existing call and corrupt all_processed().
        call_id = self._unique_call_id(calls)
        calls[call_id] = TrackedCall(call_id, fn_name, False, payload)
        return call_id

    def mark_processed(self, message_id: str, call_id: str, payload: str = "") -> None:
        """Mark one call as processed and optionally replace its payload.

        Unknown *message_id* / *call_id* combinations are ignored. An empty
        *payload* (the default) leaves the stored payload unchanged.
        """
        call = self._by_msg.get(message_id, {}).get(call_id)
        if call is None:
            return
        call.IsProcessed = True
        if payload:
            call.Payload = payload

    def all_processed(self, message_id: str) -> bool:
        """Return True if *message_id* has at least one call and all are processed.

        An unknown message, or one with no tracked calls, yields False.
        """
        calls = self._by_msg.get(message_id, {})
        return bool(calls) and all(c.IsProcessed for c in calls.values())

    def processed_list(self, message_id: str) -> List[TrackedCall]:
        """Return every call tracked for *message_id*, in insertion order.

        NOTE(review): despite the name, calls are returned whether or not they
        are processed; presumably callers check ``all_processed`` first -- confirm.
        The list itself is new, but its items are the live TrackedCall objects.
        """
        return list(self._by_msg.get(message_id, {}).values())

    def clear(self, message_id: str) -> None:
        """Forget all calls for *message_id*; an unknown ID is ignored."""
        self._by_msg.pop(message_id, None)
|
|
|