| | import asyncio |
| | import json |
| | import ssl |
| | from threading import Thread |
| |
|
| | from websockets.server import serve |
| |
|
| | from extensions.api.util import ( |
| | build_parameters, |
| | try_start_cloudflared, |
| | with_api_lock |
| | ) |
| | from modules import shared |
| | from modules.chat import generate_chat_reply |
| | from modules.text_generation import generate_reply |
| | from modules.logging_colors import logger |
| |
|
# URL path of the text-completion streaming endpoint. The connection router
# matches it, and it is appended to the URLs logged at startup.
PATH = "/api/v1/stream"
| |
|
| |
|
@with_api_lock
async def _handle_stream_message(websocket, message):
    """Generate a completion for one request and stream it back as text deltas.

    ``message`` is a JSON string. Its ``prompt`` key is the prompt; the whole
    object is turned into generation parameters by ``build_parameters``.

    Each newly generated slice of text is sent as a ``text_stream`` event with
    an increasing ``message_num``. After generation finishes, a final
    ``stream_end`` event is sent whose ``message_num`` equals the number of
    ``text_stream`` events that were sent.

    Raises:
        json.JSONDecodeError: if ``message`` is not valid JSON.
        KeyError: if ``prompt`` is missing from the request.
    """
    body = json.loads(message)

    prompt = body['prompt']
    generate_params = build_parameters(body)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    generator = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    skip_index = 0  # number of characters already sent to the client
    message_num = 0
    latest = None   # most recent full text produced by the generator

    for a in generator:
        # Guard the full text, not the slice: slicing a str never yields None,
        # and slicing None would raise before any check on the slice.
        if a is None:
            continue

        latest = a
        to_send = a[skip_index:]
        # Hold back text containing U+FFFD (the Unicode replacement character).
        # Presumably this is a multi-byte character that has not been fully
        # decoded yet; it is sent once a later chunk no longer contains it.
        if chr(0xfffd) in to_send:
            continue

        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': to_send
        }))

        # Yield to the event loop so the frame is flushed between chunks.
        await asyncio.sleep(0)
        skip_index += len(to_send)
        message_num += 1

    # A U+FFFD that never resolves would otherwise suppress every later chunk.
    # Send whatever text is still pending before ending the stream.
    if latest is not None and len(latest) > skip_index:
        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': latest[skip_index:]
        }))
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))
| |
|
| |
|
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    """Stream chat history snapshots for a single chat request.

    ``message`` is a JSON string. It must contain ``user_input`` and may set
    ``regenerate`` and ``_continue`` (both default to False). Every value
    yielded by ``generate_chat_reply`` is sent as a ``history`` in a
    ``text_stream`` event. A closing ``stream_end`` event carries the number of
    events sent.
    """
    request = json.loads(message)

    user_input = request['user_input']
    params = build_parameters(request, chat=True)
    params['stream'] = True

    generator = generate_chat_reply(
        user_input,
        params,
        regenerate=request.get('regenerate', False),
        _continue=request.get('_continue', False),
        loading_message=False,
    )

    sent = 0
    for history in generator:
        payload = {'event': 'text_stream', 'message_num': sent, 'history': history}
        await websocket.send(json.dumps(payload))
        # Give the event loop a chance to flush the frame before continuing.
        await asyncio.sleep(0)
        sent += 1

    await websocket.send(json.dumps({'event': 'stream_end', 'message_num': sent}))
| |
|
| |
|
async def _handle_connection(websocket, path):
    """Route a websocket connection to the message handler for its URL path.

    Every message received on the connection is passed, in order, to the
    handler for ``path``. For an unknown path, a warning is logged and the
    handler returns without reading any messages.
    """
    if path == PATH:
        handler = _handle_stream_message
    elif path == '/api/v1/chat-stream':
        handler = _handle_chat_stream_message
    else:
        logger.warning(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        await handler(websocket, message)
| |
|
| |
|
async def _run(host: str, port: int):
    """Serve websocket connections on ``host:port`` until this task is cancelled.

    TLS is enabled only when both ``shared.args.ssl_certfile`` and
    ``shared.args.ssl_keyfile`` are set. Otherwise the server runs plain
    websockets.
    """
    certfile = shared.args.ssl_certfile
    keyfile = shared.args.ssl_keyfile

    context = None
    if keyfile and certfile:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile, keyfile)

    # ping_interval=None turns off websockets' keepalive pings. Presumably this
    # keeps long generations from being dropped; confirm before changing.
    async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
        # A bare Future never completes, so the server stays up until cancelled.
        await asyncio.Future()
| |
|
| |
|
def _run_server(port: int, share: bool = False, tunnel_id=None):
    """Log the streaming API URL, then run the websocket server on this thread.

    The call blocks for as long as the server runs.

    Args:
        port: TCP port to listen on. The server binds 0.0.0.0 when
            ``shared.args.listen`` is set, otherwise 127.0.0.1.
        share: when True, also try to expose the port through
            ``try_start_cloudflared``. If that fails, the error is logged and
            the server still starts locally.
        tunnel_id: optional tunnel identifier forwarded to
            ``try_start_cloudflared``. Previously the default was mistakenly
            the ``str`` type itself rather than a value.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    use_ssl = bool(shared.args.ssl_keyfile and shared.args.ssl_certfile)
    scheme = 'wss' if use_ssl else 'ws'

    def on_start(public_url: str):
        # Convert the tunnel's https:// URL to the matching secure-websocket scheme.
        public_url = public_url.replace('https://', 'wss://')
        logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception as e:
            # Best effort: the tunnel is optional, so keep serving locally.
            logger.error(f'Streaming API: could not start cloudflared tunnel: {e}')
    else:
        logger.info(f'Streaming API URL: \n\n{scheme}://{address}:{port}{PATH}\n')

    asyncio.run(_run(host=address, port=port))
| |
|
| |
|
def start_server(port: int, share: bool = False, tunnel_id=None):
    """Start the streaming websocket server on a background daemon thread.

    The call returns immediately. Because the thread is a daemon, it does not
    keep the process alive on its own.

    Args:
        port: TCP port to listen on.
        share: when True, also try to expose the server through a cloudflared
            tunnel.
        tunnel_id: optional tunnel identifier passed through to the tunnel
            setup. The default is now None; previously it was mistakenly the
            ``str`` type itself.
    """
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
| |
|