Spaces:
Runtime error
| import gradio as gr | |
| from datasets import load_dataset | |
| from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel | |
| import torch | |
| from tqdm.auto import tqdm | |
| import numpy as np | |
| import time | |
def CLIP_model(model_id='openai/clip-vit-base-patch32'):
    """Load the CLIP model, fast tokenizer and processor into module globals.

    Sets the globals ``model``, ``token`` and ``processor``. These are read by
    ``embedding_input``, ``embedding_img`` and ``process``.

    Args:
        model_id: Hugging Face Hub checkpoint name (or local path). It is
            passed to every ``from_pretrained`` call below. The default is the
            checkpoint this app was originally hard-coded to use, so existing
            no-argument calls behave exactly as before.
    """
    global model, token, processor
    model = CLIPModel.from_pretrained(model_id)
    token = CLIPTokenizerFast.from_pretrained(model_id)
    processor = CLIPProcessor.from_pretrained(model_id)
def load_data():
    """Load the ``frgfm/imagenette`` ``full_size`` train split into global ``data``.

    The original call passed ``ignore_verifications=False``. That keyword was
    deprecated in ``datasets`` 2.9.1 in favour of ``verification_mode`` and
    was scheduled for removal in 3.0. Releases that no longer accept it make
    ``load_dataset`` fail at startup. ``False`` only ever meant "do not skip
    verification", which was also the old default, so the keyword is simply
    dropped and the library's default verification applies.
    """
    global data
    data = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
    )
def embedding_input(text_input):
    """Return the CLIP text embedding for ``text_input``.

    Uses the globals ``token`` (tokenizer) and ``model`` set by ``CLIP_model``.

    Args:
        text_input: the query string to embed (a list of strings also works).

    Returns:
        The torch tensor produced by ``model.get_text_features``. It is
        computed under ``torch.no_grad()``, so it carries no autograd graph.
        Callers may still call ``.detach()`` on it.
    """
    # truncation=True: CLIP's text encoder has a fixed number of position
    # embeddings (77 tokens for the OpenAI checkpoints). Without truncation,
    # a longer query makes get_text_features raise instead of returning.
    # padding=True lets a list of queries of different lengths be batched.
    token_input = token(
        text_input,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: no need to record operations for backprop.
    with torch.no_grad():
        text_embedd = model.get_text_features(**token_input)
    return text_embedd
def embedding_img(batch_size=10):
    """Embed every image in ``data['image']`` with CLIP, in batches.

    Reads the globals ``data``, ``processor`` and ``model``. Sets the globals
    ``images`` (the dataset's image column) and ``img_arr`` (the stacked
    image embeddings, one row per image, in dataset order).

    Args:
        batch_size: number of images sent through the model per forward pass.
            Defaults to 10, the value previously hard-coded.

    Returns:
        The tuple ``(images, img_arr)``. ``img_arr`` is ``None`` when the
        dataset has no images, matching the original behaviour.
    """
    global img_arr, images
    images = data['image']
    chunks = []
    for start in tqdm(range(0, len(images), batch_size)):
        pixel_values = processor(
            images=images[start:start + batch_size],
            return_tensors='pt',
        )['pixel_values']
        # Inference only: skip building an autograd graph for each batch.
        with torch.no_grad():
            batch_emb = model.get_image_features(pixel_values=pixel_values)
        # Keep the batch axis. The old .squeeze(0) turned a final batch of a
        # single image into a 1-D vector, which np.concatenate(axis=0) then
        # rejected when joined with the earlier 2-D batches.
        chunks.append(batch_emb.numpy())
    # Concatenate once instead of re-copying the growing array every batch.
    img_arr = np.concatenate(chunks, axis=0) if chunks else None
    return images, img_arr
def main():
    """Build the image index, then serve the text-to-image search UI.

    Setup runs in a fixed order because later steps read globals that earlier
    steps set: the model/processor from ``CLIP_model`` and the dataset from
    ``load_data`` are both needed by ``embedding_img``. The Gradio interface
    then sends each text query to ``process`` and shows the returned image.
    """
    for setup_step in (CLIP_model, load_data, embedding_img):
        setup_step()

    demo = gr.Interface(
        fn=process,
        inputs="text",
        outputs="image",
    )
    demo.launch(inline=False)
def process(text):
    """Return the dataset image whose CLIP embedding best matches ``text``.

    Images are scored by the dot product between the query's text embedding
    and each L2-normalised row of the global ``img_arr``. The text embedding
    is not normalised. Scaling it by a positive constant does not change which
    image scores highest, so normalising it would not change the result.

    Args:
        text: free-text query from the Gradio textbox.

    Returns:
        The element of the global ``images`` with the highest score.
    """
    text_emb = embedding_input(text).detach().numpy()
    # Divide each image embedding (row) by its own L2 norm.
    image_emb = img_arr / np.linalg.norm(img_arr, axis=1, keepdims=True)
    scores = text_emb @ image_emb.T
    # Only the single best match is needed, so use argmax (O(N)) rather than
    # fully sorting every score (O(N log N)). Cast to int so the index is a
    # plain Python int regardless of the container type of ``images``.
    best = int(np.argmax(scores[0]))
    return images[best]
# Run the setup and launch the app only when this file is executed as a
# script; importing the module defines the functions without side effects.
if __name__ == "__main__":
    main()