test / src /modules /models /model_loader.py
Akcom's picture
Optimize for HF spaces
8396f74
import os
import torch
from PIL import Image
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
from src.utils.singleton import Singleton
from src.modules.config import Config, config as ConfigObj
from src.modules.logger import logger as Logger
class ModelLoader(metaclass=Singleton):
    """Load a causal-LM vision model plus its tokenizer and answer questions about images.

    The model and tokenizer are loaded once in ``__init__`` and reused by
    ``image_describe`` and ``image_ask``.

    ``Singleton`` is applied through the ``metaclass=`` keyword: the Python 2
    style ``__metaclass__`` class attribute is ignored by Python 3, so it never
    took effect.
    NOTE(review): assumes ``src.utils.singleton.Singleton`` is a metaclass
    (a subclass of ``type``) -- confirm against that module.
    """

    # Pinned model revision. Notes from earlier trials: "2024-03-04" was very
    # slow; "main" and "2024-05-08" did not work.
    _REVISION = "2024-08-26"

    def __init__(self, conf: Config):
        """Load the model and tokenizer named by ``conf.model_name``.

        Args:
            conf: Project configuration. The attributes read here are
                ``model_name``, ``gpu_mode`` and ``models_cache_dir``.
        """
        print(conf.model_name)

        use_gpu = conf.gpu_mode

        # Arguments shared by the GPU and CPU load paths.
        load_kwargs = {
            "trust_remote_code": True,
            "revision": self._REVISION,
            "cache_dir": conf.models_cache_dir,
        }
        if use_gpu:
            # GPU-only options: bfloat16 weights, everything mapped to CUDA,
            # and the flash-attention-2 implementation.
            load_kwargs.update(
                torch_dtype=torch.bfloat16,
                device_map={"": "cuda"},
                attn_implementation="flash_attention_2",
            )

        model = AutoModelForCausalLM.from_pretrained(conf.model_name, **load_kwargs)
        if use_gpu:
            model = model.to("cuda")
        self._model = model

        self._tokenizer = AutoTokenizer.from_pretrained(
            conf.model_name,
            revision=self._REVISION,
            cache_dir=conf.models_cache_dir,
        )
        # Switch to evaluation mode; this loader is only used for inference.
        self._model.eval()

    def image_describe(self, image_path):
        """Return the model's answer to the fixed prompt "Describe this image.".

        Args:
            image_path: Location of the image, passed to ``PIL.Image.open``.

        Returns:
            Whatever ``image_ask`` returns for the fixed prompt.
        """
        return self.image_ask(image_path, "Describe this image.")

    def image_ask(self, image_path, question):
        """Ask the model ``question`` about the image at ``image_path``.

        Args:
            image_path: Location of the image, passed to ``PIL.Image.open``.
            question: Prompt forwarded to the model's ``answer_question``.

        Returns:
            The value returned by the model's ``answer_question``
            (presumably the answer text -- confirm against the model code).
        """
        # Close the image file as soon as it has been encoded instead of
        # leaving the handle open until garbage collection.
        # NOTE(review): assumes encode_image fully consumes the image before
        # returning -- confirm against the model's remote code.
        with Image.open(image_path) as image:
            enc_image = self._model.encode_image(image)
        return self._model.answer_question(enc_image, question, self._tokenizer)
# Shared module-level loader built from the project config object.
# Constructing it runs ModelLoader.__init__, so the model and tokenizer are
# loaded when this module is first imported.
model_loader = ModelLoader(ConfigObj)