Spaces:
Sleeping
Sleeping
File size: 2,458 Bytes
2184c4e 4ddb23c 2184c4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
import torch
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel
class ClipModel:
    """Wrapper around a fine-tuned Chinese-CLIP model that ranks a Chinese
    vocabulary list by similarity to a query image.

    Loads the base architecture from the Hugging Face hub, then overwrites
    its weights with a locally fine-tuned state dict.
    """

    def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None):
        """Load the model and processor, then move the model to GPU if available.

        Args:
            model_name: Hub identifier of the base Chinese-CLIP checkpoint.
            model_path: Path to a fine-tuned ``state_dict`` file. Defaults to
                ``<script_dir>/artifacts/models/best_clip_model.pth``.
        """
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load model and processor
        self.model = ChineseCLIPModel.from_pretrained(model_name)
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
        # weights_only=True blocks arbitrary-code execution from a tampered
        # pickle (default in torch >= 2.6); plain state dicts load fine with it.
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True)
        )
        self.model = self.model.to(self.device)
        self.model.eval()
        self.processor = ChineseCLIPProcessor.from_pretrained(model_name)

    def clip_result(self, image_path, vocab_path='./chiikawa/word_list.txt', top_k=3):
        """Return the ``top_k`` vocabulary words most similar to the image.

        Args:
            image_path: Path to the query image.
            vocab_path: UTF-8 text file with one candidate word per line.
            top_k: Number of best-matching words to return.

        Returns:
            List of the closest Chinese words, highest similarity first
            (empty list if the vocabulary file has no words).
        """
        # Context manager closes the image file handle (the original leaked it);
        # convert("RGB") normalizes grayscale/palette/RGBA inputs.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # Load Chinese vocabulary, skipping blank lines so they can never
        # appear as empty-string candidates.
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = [line.strip() for line in f if line.strip()]
        if not vocab:
            return []

        batch_size = 16  # number of vocabulary words scored per forward pass
        # Release unused cached GPU memory; only meaningful on CUDA.
        if self.device == "cuda":
            torch.cuda.empty_cache()

        with torch.no_grad():
            # Encode the image ONCE. The original re-ran the full vision tower
            # for every text batch; image_embeds are identical across batches.
            image_inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            image_embeds = self.model.get_image_features(**image_inputs)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            similarities = []
            for start in range(0, len(vocab), batch_size):
                batch_vocab = vocab[start:start + batch_size]
                text_inputs = self.processor(
                    text=batch_vocab,
                    return_tensors="pt",
                    padding=True,
                ).to(self.device)
                # Text side of the similarity: normalized projected embeddings,
                # same math as reading outputs.text_embeds from a full forward.
                text_embeds = self.model.get_text_features(**text_inputs)
                text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
                similarities.append(torch.matmul(image_embeds, text_embeds.T).squeeze(0))

        # Merge all per-batch similarity scores
        similarity = torch.cat(similarities, dim=0)
        # Clamp k so a vocabulary shorter than top_k cannot make topk raise.
        k = min(top_k, similarity.numel())
        top_k_indices = torch.topk(similarity, k).indices.tolist()
        # Map indices back to the closest Chinese words.
        return [vocab[idx] for idx in top_k_indices]
|