# Spaces: No application file
import torch
from fastapi import APIRouter
from transformers import BertForSequenceClassification, BertTokenizer

from src.kr_api.models.ai_request import Request
# Path handed to ``from_pretrained`` for both the classifier and its
# tokenizer. It is relative, so it resolves against the process working
# directory.
model_name = "src/kr_api/kalium_recommend"

# Loaded once at import time so every call to the handler reuses the same
# model and tokenizer instances instead of reloading them.
model = BertForSequenceClassification.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

# Use a CUDA device when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

router = APIRouter()
async def ai(request: Request):
    """Run the BERT sequence classifier on a description built from *request*.

    The request's ``audience``, ``category``, ``area`` and ``sub_area``
    attributes are formatted into one string, tokenized (padded or truncated
    to exactly 255 tokens), moved to the module-level ``device`` and passed
    through the module-level ``model`` with gradient tracking disabled.

    Args:
        request: Object exposing ``audience``, ``category``, ``area`` and
            ``sub_area`` attributes.

    Returns:
        A dict ``{"similarity": logits}`` where ``logits`` is the model's
        ``outputs.logits`` tensor converted to nested Python lists. No
        softmax or other post-processing is applied here.

    NOTE(review): no route decorator is visible on this function; confirm
    how it is registered on ``router``.
    NOTE(review): nothing in this coroutine is awaited, so tokenization and
    the forward pass run synchronously on the event-loop thread.
    """
    # NOTE(review): the sentences are concatenated with no separator except
    # the single space after the category value. The literals are kept
    # exactly as they were because the model presumably expects this format;
    # confirm against the training data before changing them.
    text = (
        f"This ads is focused in: {request.audience}"
        f"This category is: {request.category} "
        f"This area is: {request.area}"
        f"This sub area of ads: {request.sub_area}"
    )

    # Call the tokenizer directly: ``encode_plus`` is deprecated in favour of
    # ``__call__``, which takes the same arguments for a single text.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=255,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    return {"similarity": outputs.logits.tolist()}