from time import time

import numpy as np
import torch
from PIL import Image
from loguru import logger
from numpy import ndarray
from torch import FloatTensor, no_grad
from transformers import CLIPProcessor, CLIPModel, BertTokenizer, BertModel

from app.Services.lifespan_service import LifespanService
from app.config import config


class TransformersService(LifespanService):
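    """Loads the CLIP model (plus a BERT model when OCR search is enabled) and
    exposes helpers for turning images and text into embedding vectors."""
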
    def __init__(self):
        self.device = config.device
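        # "auto" resolves to CUDA when a GPU is available, otherwise falls back to CPU.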
        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}",
                    self.device, config.model.clip, config.model.bert)
        self._clip_model = CLIPModel.from_pretrained(config.model.clip).to(self.device)
        self._clip_processor = CLIPProcessor.from_pretrained(config.model.clip)
        logger.success("CLIP Model loaded successfully")
        if config.ocr_search.enable:
            self._bert_model = BertModel.from_pretrained(config.model.bert).to(self.device)
            self._bert_tokenizer = BertTokenizer.from_pretrained(config.model.bert)
            logger.success("BERT Model loaded successfully")
        else:
            logger.info("OCR search is disabled. Skipping BERT model loading.")

    @no_grad()
    def get_image_vector(self, image: Image.Image) -> ndarray:
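        """Encode a PIL image with CLIP into an L2-normalized 1-D numpy vector."""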
        if image.mode != "RGB":
            image = image.convert("RGB")
        logger.info("Processing image...")
        start_time = time()
        inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device)
        logger.success("Image processed, now Inferring with CLIP model...")
        outputs: FloatTensor = self._clip_model.get_image_features(**inputs)
        logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
        outputs /= outputs.norm(dim=-1, keepdim=True)
        return outputs.numpy(force=True).reshape(-1)

    @no_grad()
    def get_text_vector(self, text: str) -> ndarray:
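        """Encode a text query with CLIP into an L2-normalized 1-D numpy vector."""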
        logger.info("Processing text...")
        start_time = time()
        inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device)
        logger.success("Text processed, now Inferring with CLIP model...")
        outputs: FloatTensor = self._clip_model.get_text_features(**inputs)
        logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
        outputs /= outputs.norm(dim=-1, keepdim=True)
        return outputs.numpy(force=True).reshape(-1)

    @no_grad()
    def get_bert_vector(self, text: str) -> ndarray:
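        """Embed text with BERT by mean-pooling the last hidden state over all tokens."""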
        start_time = time()
        logger.info("Inferring with BERT model...")
        inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt", truncation=True).to(self.device)
        outputs = self._bert_model(**inputs)
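        # Mean-pool the token embeddings (dim=1) into a single sentence-level vector.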
        vector = outputs.last_hidden_state.mean(dim=1).squeeze()
        logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time)
        return vector.cpu().numpy()

    @staticmethod
    def get_random_vector(seed: int | None = None) -> ndarray:
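        """Return a reproducible random vector; 768 presumably matches the CLIP embedding width."""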
        generator = np.random.default_rng(seed)
        vec = generator.uniform(-1, 1, 768)
        return vec
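
# Illustrative usage sketch (not part of the service): assumes the app's `config`
# points at valid CLIP/BERT checkpoints and that an image exists at the
# hypothetical path "example.jpg". Both CLIP vectors are unit-norm, so their dot
# product is their cosine similarity.
#
#   service = TransformersService()
#   with Image.open("example.jpg") as img:
#       image_vec = service.get_image_vector(img)
#   text_vec = service.get_text_vector("a photo of a cat")
#   similarity = float(np.dot(image_vec, text_vec))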