| | import argparse |
| | import markdown2 |
| | import os |
| | import sys |
| | import uvicorn |
| |
|
| | from pathlib import Path |
| | from typing import Union |
| |
|
| | from fastapi import FastAPI, Depends, HTTPException |
| | from fastapi.responses import HTMLResponse |
| | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| | from pydantic import BaseModel, Field |
| | from sse_starlette.sse import EventSourceResponse, ServerSentEvent |
| | from tclogger import logger |
| |
|
| | from constants.models import AVAILABLE_MODELS_DICTS, PRO_MODELS |
| | from constants.envs import CONFIG, SECRETS |
| | from networks.exceptions import HfApiException, INVALID_API_KEY_ERROR |
| |
|
| | from messagers.message_composer import MessageComposer |
| | from mocks.stream_chat_mocker import stream_chat_mock |
| |
|
| | from networks.huggingface_streamer import HuggingfaceStreamer |
| | from networks.huggingchat_streamer import HuggingchatStreamer |
| | from networks.openai_streamer import OpenaiStreamer |
| |
|
| |
|
class ChatAPIApp:
    """FastAPI application exposing OpenAI-compatible chat endpoints backed by
    Huggingface, Huggingchat, and OpenAI streamers.

    Routes are mounted under several prefixes ("", "/v1", "/api", "/api/v1")
    for client compatibility; only "/api/v1" appears in the Swagger schema.
    """

    def __init__(self):
        # docs_url="/" serves Swagger UI at the root;
        # defaultModelsExpandDepth=-1 hides the "Schemas" section in the UI.
        self.app = FastAPI(
            docs_url="/",
            title=CONFIG["app_name"],
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=CONFIG["version"],
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the available models in OpenAI `/models` list format."""
        return {"object": "list", "data": AVAILABLE_MODELS_DICTS}

    # NOTE: deliberately defined without `self` — this is used as a plain
    # FastAPI dependency callable via Depends(self.extract_api_key) below,
    # where FastAPI inspects the signature for injection.
    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
    ):
        """Extract the bearer token from the Authorization header, or None."""
        api_key = None
        if credentials:
            api_key = credentials.credentials
        return api_key

    def auth_api_key(self, api_key: str):
        """Authorize `api_key` against the configured server key.

        Returns:
            The caller's key when it is a Huggingface token ("hf_..."), to be
            forwarded upstream; None when access is open or the key matches
            the configured server key (authorized, nothing to forward).

        Raises:
            INVALID_API_KEY_ERROR: when a server key is configured and the
                supplied key neither matches it nor is an "hf_" token.
        """
        env_api_key = SECRETS["HF_LLM_API_KEY"]
        # No server-side key configured: open access, no key to forward.
        if not env_api_key:
            return None
        # A user-supplied Huggingface token is passed through to the backend.
        if api_key and api_key.startswith("hf_"):
            return api_key
        # Matches the configured server key: authorized, nothing to forward.
        if str(api_key) == str(env_api_key):
            return None
        raise INVALID_API_KEY_ERROR

    class ChatCompletionsPostItem(BaseModel):
        """Request body for the `/chat/completions` endpoint."""

        model: str = Field(
            default="nous-mixtral-8x7b",
            description="(str) `nous-mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        """Dispatch a chat-completion request to the appropriate streamer.

        Model routing: "gpt-3.5-turbo" -> OpenaiStreamer; PRO_MODELS ->
        HuggingchatStreamer; anything else -> HuggingfaceStreamer with a
        prompt merged by MessageComposer. Returns an SSE stream when
        `item.stream` is true, otherwise a single response dict.
        """
        try:
            api_key = self.auth_api_key(api_key)

            if item.model == "gpt-3.5-turbo":
                streamer = OpenaiStreamer()
                stream_response = streamer.chat_response(messages=item.messages)
            elif item.model in PRO_MODELS:
                streamer = HuggingchatStreamer(model=item.model)
                stream_response = streamer.chat_response(
                    messages=item.messages,
                )
            else:
                streamer = HuggingfaceStreamer(model=item.model)
                composer = MessageComposer(model=item.model)
                composer.merge(messages=item.messages)
                stream_response = streamer.chat_response(
                    prompt=composer.merged_str,
                    temperature=item.temperature,
                    top_p=item.top_p,
                    max_new_tokens=item.max_tokens,
                    api_key=api_key,
                    use_cache=item.use_cache,
                )

            if item.stream:
                # Periodic ping comments (every 2000 ms) keep proxies from
                # closing an otherwise-idle SSE connection.
                event_source_response = EventSourceResponse(
                    streamer.chat_return_generator(stream_response),
                    media_type="text/event-stream",
                    ping=2000,
                    ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
                )
                return event_source_response
            else:
                data_response = streamer.chat_return_dict(stream_response)
                return data_response
        except HfApiException as e:
            # Backend errors carry their own HTTP status; surface it as-is.
            raise HTTPException(status_code=e.status_code, detail=e.detail)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    def get_readme(self):
        """Render the repository README.md (one directory up) as HTML."""
        readme_path = Path(__file__).parents[1] / "README.md"
        with open(readme_path, "r", encoding="utf-8") as rf:
            readme_str = rf.read()
        readme_html = markdown2.markdown(
            readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
        )
        return readme_html

    def setup_routes(self):
        """Mount model-list and chat routes under all prefixes, plus /readme."""
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            # Only the canonical "/api/v1" prefix is shown in Swagger docs;
            # the other prefixes exist for client compatibility.
            if prefix in ["/api/v1"]:
                include_in_schema = True
            else:
                include_in_schema = False

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)
        self.app.get(
            "/readme",
            summary="README of HF LLM API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)
|
class ArgParser(argparse.ArgumentParser):
    """Argument parser for the API server: host (-s), port (-p), dev mode (-d).

    Parses sys.argv eagerly on construction and exposes the result as
    `self.args`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Hoist the repeated config lookup used in the help strings.
        app_name = CONFIG["app_name"]

        self.add_argument(
            "-s",
            "--host",
            type=str,
            default=CONFIG["host"],
            help=f"Host for {app_name}",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=CONFIG["port"],
            help=f"Port for {app_name}",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])
|
| |
|
# Module-level ASGI app so `uvicorn __main__:app` (and external runners)
# can import it directly.
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # The dev/prod branches differed only in `reload`; pass the flag through
    # directly instead of duplicating the uvicorn.run call.
    uvicorn.run("__main__:app", host=args.host, port=args.port, reload=args.dev)
|