# KoFace-AI/clip/models/ko_clip.py
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class KoCLIPWrapper(nn.Module):
    """Thin wrapper around the Korean CLIP checkpoint Bingsu/clip-vit-base-patch32-ko."""

    def __init__(self):
        super().__init__()
        self.model_name = "Bingsu/clip-vit-base-patch32-ko"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)

    def loss(self, inputs):
        # Run a full forward pass and return only the contrastive loss.
        outputs = self(inputs)
        return outputs.loss

    def text_encode(self, text, tokenize):
        # If `tokenize` is True, `text` is a raw string (or list of strings)
        # and is padded/truncated to CLIP's fixed context length of 77 tokens;
        # otherwise `text` is assumed to be an already-tokenized batch.
        if tokenize:
            tokens = self.tokenizer(
                text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )
        else:
            tokens = text
        tokens = tokens.to(self.model.device)
        return self.model.get_text_features(**tokens)

    def forward(self, inputs):
        # `inputs` is expected to be a processor-built batch containing
        # `input_ids`, `attention_mask`, and `pixel_values`. With
        # return_loss=True the output carries the CLIP contrastive loss;
        # the image and text embeddings are each of shape [batch_size, 512].
        outputs = self.model(**inputs, return_loss=True)
        return outputs