johnbridges committed on
Commit
66c4f69
·
1 Parent(s): be6d3d6

update rabbitmq code to make it more robust

Browse files
Files changed (4) hide show
  1. listener.py +23 -15
  2. oa_server.py +17 -7
  3. rabbit_base.py +16 -12
  4. rabbit_repo.py +39 -13
listener.py CHANGED
@@ -6,7 +6,6 @@ import aio_pika
6
 
7
  Handler = Callable[[Any], Awaitable[None]] # payload is envelope["data"]
8
 
9
- import logging
10
  logger = logging.getLogger(__name__)
11
 
12
 
@@ -31,27 +30,36 @@ class RabbitListenerBase:
31
  q = await self._base.declare_queue_bind(
32
  exchange=exch, queue_name=qname, routing_keys=rks, ttl_ms=ttl
33
  )
34
- await q.consume(self._make_consumer(d["FuncName"]))
 
35
  self._consumers.append(q)
36
 
37
  def _make_consumer(self, func_name: str):
38
  handler = self._handlers.get(func_name)
39
 
40
  async def _on_msg(msg: aio_pika.IncomingMessage):
41
- async with msg.process():
42
- try:
43
- raw_body = msg.body.decode("utf-8", errors="replace")
44
- logger.info(
45
- "Received message for handler '%s': %s",
46
- func_name,
47
- raw_body
48
- )
49
 
 
 
50
  envelope = json.loads(raw_body)
51
- data = envelope.get("data", None)
52
- if handler:
53
- await handler(data)
54
- except Exception as e:
55
- logger.exception("Error processing message for '%s'", func_name)
 
 
 
 
 
 
 
 
 
 
56
 
57
  return _on_msg
 
6
 
7
  Handler = Callable[[Any], Awaitable[None]] # payload is envelope["data"]
8
 
 
9
  logger = logging.getLogger(__name__)
10
 
11
 
 
30
  q = await self._base.declare_queue_bind(
31
  exchange=exch, queue_name=qname, routing_keys=rks, ttl_ms=ttl
32
  )
33
+ # explicit manual-ack, parity with .NET (autoAck: false)
34
+ await q.consume(self._make_consumer(d["FuncName"]), no_ack=False)
35
  self._consumers.append(q)
36
 
37
  def _make_consumer(self, func_name: str):
38
  handler = self._handlers.get(func_name)
39
 
40
  async def _on_msg(msg: aio_pika.IncomingMessage):
41
+ # manual ack after handler completes; no nack/requeue loops
42
+ try:
43
+ raw_body = msg.body.decode("utf-8", errors="replace")
44
+ logger.info("Received message for handler '%s': %s", func_name, raw_body)
 
 
 
 
45
 
46
+ # safe JSON parse to mirror .NET ConvertToObject (no throw)
47
+ try:
48
  envelope = json.loads(raw_body)
49
+ except Exception:
50
+ logger.exception("Invalid JSON for '%s'", func_name)
51
+ envelope = {"data": None}
52
+
53
+ data = envelope.get("data", None)
54
+
55
+ if handler:
56
+ await handler(data)
57
+ else:
58
+ logger.error("No handler bound for '%s'", func_name)
59
+
60
+ await msg.ack() # ack on success path
61
+ except Exception:
62
+ # match .NET: on exception, do not ack or nack; connection loss will requeue
63
+ logger.exception("Error processing message for '%s'", func_name)
64
 
65
  return _on_msg
oa_server.py CHANGED
@@ -34,12 +34,12 @@ class OpenAIServers:
34
  """
35
 
36
  def __init__(self, publisher: RabbitRepo,
37
- *, chat_backend=None, images_backend=None):
 
38
  self._pub = publisher
39
  self._chat = chat_backend
40
  self._img = images_backend
41
 
42
-
43
  # -------- Chat Completions --------
44
  async def handle_chat_create(self, data: Dict[str, Any]) -> None:
45
  """
@@ -57,10 +57,17 @@ class OpenAIServers:
57
 
58
  try:
59
  async for chunk in self._chat.stream(data):
60
- # CloudEvent-wrapped OpenAI chunk to oa.chat.reply
61
- await self._pub.publish("oa.chat.reply", chunk, routing_key=reply_key)
 
 
 
 
62
  # Optional sentinel
63
- await self._pub.publish("oa.chat.reply", {"object": "stream.end"}, routing_key=reply_key)
 
 
 
64
  except Exception:
65
  logger.exception("oaChatCreate: streaming failed")
66
 
@@ -80,10 +87,14 @@ class OpenAIServers:
80
  try:
81
  b64 = await self._img.generate_b64(data)
82
  resp = {"created": _now(), "data":[{"b64_json": b64}]}
83
- await self._pub.publish("oa.images.reply", resp, routing_key=reply_key)
 
 
 
84
  except Exception:
85
  logger.exception("oaImagesGenerate: generation failed")
86
 
 
87
  # --- at the bottom of oa_server.py ---
88
  # Provide aliases expected by vllm_backend.py
89
  try:
@@ -103,4 +114,3 @@ except NameError:
103
  ImagesBackend = ImageGenerationsBackend # noqa: F821
104
  except Exception:
105
  pass
106
-
 
34
  """
35
 
36
  def __init__(self, publisher: RabbitRepo,
37
+ *, chat_backend: Optional[ChatBackend] = None,
38
+ images_backend: Optional[ImagesBackend] = None):
39
  self._pub = publisher
40
  self._chat = chat_backend
41
  self._img = images_backend
42
 
 
43
  # -------- Chat Completions --------
44
  async def handle_chat_create(self, data: Dict[str, Any]) -> None:
45
  """
 
57
 
58
  try:
59
  async for chunk in self._chat.stream(data):
60
+ try:
61
+ await self._pub.publish("oa.chat.reply", chunk, routing_key=reply_key)
62
+ except Exception:
63
+ logger.exception("oaChatCreate: publish failed")
64
+ break # stop streaming on publish failure
65
+
66
  # Optional sentinel
67
+ try:
68
+ await self._pub.publish("oa.chat.reply", {"object": "stream.end"}, routing_key=reply_key)
69
+ except Exception:
70
+ logger.exception("oaChatCreate: publish sentinel failed")
71
  except Exception:
72
  logger.exception("oaChatCreate: streaming failed")
73
 
 
87
  try:
88
  b64 = await self._img.generate_b64(data)
89
  resp = {"created": _now(), "data":[{"b64_json": b64}]}
90
+ try:
91
+ await self._pub.publish("oa.images.reply", resp, routing_key=reply_key)
92
+ except Exception:
93
+ logger.exception("oaImagesGenerate: publish failed")
94
  except Exception:
95
  logger.exception("oaImagesGenerate: generation failed")
96
 
97
+
98
  # --- at the bottom of oa_server.py ---
99
  # Provide aliases expected by vllm_backend.py
100
  try:
 
114
  ImagesBackend = ImageGenerationsBackend # noqa: F821
115
  except Exception:
116
  pass
 
rabbit_base.py CHANGED
@@ -8,11 +8,10 @@ import aio_pika
8
 
9
  from config import settings
10
 
11
- ExchangeResolver = Callable[[str], str] # exchangeName -> exchangeType
12
-
13
- import logging
14
  logger = logging.getLogger(__name__)
15
 
 
16
  def _normalize_exchange_type(val: str) -> aio_pika.ExchangeType:
17
  if isinstance(val, str):
18
  name = val.upper()
@@ -24,6 +23,7 @@ def _normalize_exchange_type(val: str) -> aio_pika.ExchangeType:
24
  pass
25
  return aio_pika.ExchangeType.TOPIC
26
 
 
27
  def _parse_amqp_url(url: str) -> dict:
28
  parts = urlsplit(url)
29
  return {
@@ -46,7 +46,6 @@ class RabbitBase:
46
  lambda _: settings.RABBIT_EXCHANGE_TYPE
47
  )
48
 
49
- # -------- Status helpers --------
50
  def is_connected(self) -> bool:
51
  return bool(
52
  self._conn and not self._conn.is_closed and
@@ -68,15 +67,12 @@ class RabbitBase:
68
  self._conn = None
69
  logger.info("AMQP connection closed")
70
 
71
- # -------- Core ops --------
72
  async def connect(self) -> None:
73
- if self._conn and not self._conn.is_closed:
74
- if self._chan and not self._chan.is_closed:
75
- return
76
 
77
  conn_kwargs = _parse_amqp_url(str(settings.AMQP_URL))
78
 
79
- # Log connection target (mask password)
80
  safe_target = {
81
  "scheme": conn_kwargs["scheme"],
82
  "host": conn_kwargs["host"],
@@ -87,7 +83,6 @@ class RabbitBase:
87
  }
88
  logger.info("AMQP connect -> %s", json.dumps(safe_target))
89
 
90
- # TLS (intentionally disabling verification if requested)
91
  ssl_ctx = None
92
  if conn_kwargs.get("ssl"):
93
  ssl_ctx = ssl.create_default_context()
@@ -104,10 +99,15 @@ class RabbitBase:
104
  virtualhost=conn_kwargs["virtualhost"],
105
  ssl=conn_kwargs["ssl"],
106
  ssl_context=ssl_ctx,
 
 
 
107
  )
108
  logger.info("AMQP connection established")
 
109
  self._chan = await self._conn.channel()
110
  logger.info("AMQP channel created")
 
111
  await self._chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)
112
  logger.info("AMQP QoS set (prefetch=%s)", settings.RABBIT_PREFETCH)
113
  except Exception:
@@ -117,9 +117,13 @@ class RabbitBase:
117
  async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
118
  await self.connect()
119
  if name in self._exchanges:
120
- return self._exchanges[name]
 
 
 
 
121
 
122
- ex_type_str = self._exchange_type_resolver(name) # e.g. "direct"
123
  ex_type = _normalize_exchange_type(ex_type_str)
124
 
125
  try:
 
8
 
9
  from config import settings
10
 
11
+ ExchangeResolver = Callable[[str], str]
 
 
12
  logger = logging.getLogger(__name__)
13
 
14
+
15
  def _normalize_exchange_type(val: str) -> aio_pika.ExchangeType:
16
  if isinstance(val, str):
17
  name = val.upper()
 
23
  pass
24
  return aio_pika.ExchangeType.TOPIC
25
 
26
+
27
  def _parse_amqp_url(url: str) -> dict:
28
  parts = urlsplit(url)
29
  return {
 
46
  lambda _: settings.RABBIT_EXCHANGE_TYPE
47
  )
48
 
 
49
  def is_connected(self) -> bool:
50
  return bool(
51
  self._conn and not self._conn.is_closed and
 
67
  self._conn = None
68
  logger.info("AMQP connection closed")
69
 
 
70
  async def connect(self) -> None:
71
+ if self._conn and not self._conn.is_closed and self._chan and not self._chan.is_closed:
72
+ return
 
73
 
74
  conn_kwargs = _parse_amqp_url(str(settings.AMQP_URL))
75
 
 
76
  safe_target = {
77
  "scheme": conn_kwargs["scheme"],
78
  "host": conn_kwargs["host"],
 
83
  }
84
  logger.info("AMQP connect -> %s", json.dumps(safe_target))
85
 
 
86
  ssl_ctx = None
87
  if conn_kwargs.get("ssl"):
88
  ssl_ctx = ssl.create_default_context()
 
99
  virtualhost=conn_kwargs["virtualhost"],
100
  ssl=conn_kwargs["ssl"],
101
  ssl_context=ssl_ctx,
102
+ heartbeat=60, # keepalive during long CPU work
103
+ timeout=30,
104
+ client_properties={"connection_name": "hf_backend_publisher"},
105
  )
106
  logger.info("AMQP connection established")
107
+
108
  self._chan = await self._conn.channel()
109
  logger.info("AMQP channel created")
110
+
111
  await self._chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)
112
  logger.info("AMQP QoS set (prefetch=%s)", settings.RABBIT_PREFETCH)
113
  except Exception:
 
117
  async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
118
  await self.connect()
119
  if name in self._exchanges:
120
+ ex = self._exchanges[name]
121
+ if ex.channel and not ex.channel.is_closed:
122
+ return ex
123
+ # drop stale cache and recreate
124
+ self._exchanges.pop(name, None)
125
 
126
+ ex_type_str = self._exchange_type_resolver(name)
127
  ex_type = _normalize_exchange_type(ex_type_str)
128
 
129
  try:
rabbit_repo.py CHANGED
@@ -1,38 +1,65 @@
1
  # rabbit_repo.py
2
  import uuid
 
 
3
  from typing import Any, Optional
4
 
 
5
  import aio_pika
6
 
7
  from config import settings
8
  from models import CloudEvent
9
  from rabbit_base import RabbitBase
10
  from utils import to_json, json_compress_str
11
- import logging
12
  logger = logging.getLogger(__name__)
13
 
 
14
  class RabbitRepo(RabbitBase):
15
  def __init__(self, external_source: str):
16
  super().__init__(exchange_type_resolver=self._resolve_type)
17
  self._source = external_source
18
 
19
  def _resolve_type(self, exch: str) -> str:
20
- # First check for oa.* exchanges
21
  if exch.lower().startswith("oa."):
22
- return "direct" # Default for oa.* exchanges
23
-
24
- # Then check EXCHANGE_TYPES if present
25
  if hasattr(settings, 'EXCHANGE_TYPES') and settings.EXCHANGE_TYPES:
26
- matches = [k for k in settings.EXCHANGE_TYPES.keys()
27
- if exch.lower().startswith(k.lower())]
28
  if matches:
29
  return settings.EXCHANGE_TYPES[max(matches, key=len)]
30
-
31
- # Default fallback
32
  return "fanout"
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  async def publish(self, exchange: str, obj: Any, routing_key: str = "") -> None:
35
- ex = await self.ensure_exchange(exchange)
36
  payload = obj if not hasattr(obj, "model_dump") else obj.model_dump(by_alias=True)
37
  evt = CloudEvent.wrap(
38
  event_id=str(uuid.uuid4()),
@@ -41,7 +68,7 @@ class RabbitRepo(RabbitBase):
41
  data=payload,
42
  )
43
  body = evt.model_dump_json(exclude_none=True).encode("utf-8")
44
- await ex.publish(aio_pika.Message(body=body), routing_key=routing_key)
45
 
46
  async def publish_jsonz(
47
  self,
@@ -50,7 +77,6 @@ class RabbitRepo(RabbitBase):
50
  routing_key: str = "",
51
  with_id: Optional[str] = None,
52
  ) -> str:
53
- ex = await self.ensure_exchange(exchange)
54
  payload = obj if not hasattr(obj, "model_dump") else obj.model_dump(by_alias=True)
55
  datajson = to_json(payload)
56
  datajsonZ = json_compress_str(datajson)
@@ -63,5 +89,5 @@ class RabbitRepo(RabbitBase):
63
  data=wrapped,
64
  )
65
  body = evt.model_dump_json(exclude_none=True).encode("utf-8")
66
- await ex.publish(aio_pika.Message(body=body), routing_key=routing_key)
67
  return datajsonZ
 
1
  # rabbit_repo.py
2
  import uuid
3
+ import asyncio
4
+ import logging
5
  from typing import Any, Optional
6
 
7
+ import aiormq
8
  import aio_pika
9
 
10
  from config import settings
11
  from models import CloudEvent
12
  from rabbit_base import RabbitBase
13
  from utils import to_json, json_compress_str
14
+
15
  logger = logging.getLogger(__name__)
16
 
17
+
18
  class RabbitRepo(RabbitBase):
19
  def __init__(self, external_source: str):
20
  super().__init__(exchange_type_resolver=self._resolve_type)
21
  self._source = external_source
22
 
23
  def _resolve_type(self, exch: str) -> str:
 
24
  if exch.lower().startswith("oa."):
25
+ return "direct"
 
 
26
  if hasattr(settings, 'EXCHANGE_TYPES') and settings.EXCHANGE_TYPES:
27
+ matches = [k for k in settings.EXCHANGE_TYPES.keys()
28
+ if exch.lower().startswith(k.lower())]
29
  if matches:
30
  return settings.EXCHANGE_TYPES[max(matches, key=len)]
 
 
31
  return "fanout"
32
 
33
+ async def _publish_with_retry(self, exchange: str, body: bytes, routing_key: str = "") -> None:
34
+ attempts, delay = 0, 0.5
35
+ while True:
36
+ try:
37
+ ex = await self.ensure_exchange(exchange)
38
+ msg = aio_pika.Message(
39
+ body=body,
40
+ delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
41
+ )
42
+ await ex.publish(msg, routing_key=routing_key)
43
+ return
44
+ except (asyncio.CancelledError,
45
+ aiormq.exceptions.ChannelInvalidStateError,
46
+ aiormq.exceptions.ConnectionClosed,
47
+ aio_pika.exceptions.AMQPError,
48
+ RuntimeError) as e:
49
+ attempts += 1
50
+ logger.warning("publish failed attempt=%d exchange=%s rk=%s err=%r",
51
+ attempts, exchange, routing_key, e)
52
+ try:
53
+ await self.close()
54
+ except Exception:
55
+ pass
56
+ if attempts >= 5:
57
+ logger.exception("publish giving up after %d attempts", attempts)
58
+ raise
59
+ await asyncio.sleep(delay)
60
+ delay = min(delay * 2, 5.0)
61
+
62
  async def publish(self, exchange: str, obj: Any, routing_key: str = "") -> None:
 
63
  payload = obj if not hasattr(obj, "model_dump") else obj.model_dump(by_alias=True)
64
  evt = CloudEvent.wrap(
65
  event_id=str(uuid.uuid4()),
 
68
  data=payload,
69
  )
70
  body = evt.model_dump_json(exclude_none=True).encode("utf-8")
71
+ await self._publish_with_retry(exchange, body, routing_key)
72
 
73
  async def publish_jsonz(
74
  self,
 
77
  routing_key: str = "",
78
  with_id: Optional[str] = None,
79
  ) -> str:
 
80
  payload = obj if not hasattr(obj, "model_dump") else obj.model_dump(by_alias=True)
81
  datajson = to_json(payload)
82
  datajsonZ = json_compress_str(datajson)
 
89
  data=wrapped,
90
  )
91
  body = evt.model_dump_json(exclude_none=True).encode("utf-8")
92
+ await self._publish_with_retry(exchange, body, routing_key)
93
  return datajsonZ