```python
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class KoCLIPWrapper(nn.Module):
    """Thin wrapper around the Korean CLIP checkpoint Bingsu/clip-vit-base-patch32-ko."""

    def __init__(self):
        super().__init__()
        self.model_name = "Bingsu/clip-vit-base-patch32-ko"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)

    def loss(self, inputs):
        # Run a full forward pass and return the contrastive (CLIP) loss.
        outputs = self(inputs)
        return outputs.loss

    def text_encode(self, text, tokenize=True):
        # Tokenize raw strings; pass pre-tokenized inputs through unchanged.
        if tokenize:
            tokens = self.tokenizer(
                text,
                padding="max_length",
                max_length=77,  # CLIP's fixed text context length
                truncation=True,
                return_tensors="pt",
            )
        else:
            tokens = text
        tokens = tokens.to(self.model.device)
        return self.model.get_text_features(**tokens)

    def forward(self, inputs):
        # return_loss=True makes the model also compute the contrastive loss;
        # outputs include text_embeds and image_embeds, each of shape [batch, 512].
        outputs = self.model(**inputs, return_loss=True)
        return outputs
```
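For reference, a minimal usage sketch of the text-encoding path. The sample sentence and variable names are illustrative; the checkpoint is downloaded from the Hugging Face Hub on first use. The loss path additionally needs image preprocessing (e.g. via the model's processor) to supply `pixel_values`.

```python
import torch

# Instantiate the wrapper defined above (fetches the checkpoint from the Hub).
wrapper = KoCLIPWrapper()
wrapper.eval()

with torch.no_grad():
    # "photo of a cat" in Korean; any string or list of strings works.
    embeddings = wrapper.text_encode("고양이 사진")

print(embeddings.shape)  # torch.Size([1, 512])
```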