Spaces:
No application file
No application file
File size: 1,171 Bytes
598f7ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from src.kr_api.models.ai_request import Request
import torch
from fastapi import APIRouter
from transformers import BertForSequenceClassification, BertTokenizer
# Directory holding the fine-tuned checkpoint and its tokenizer files.
# The path is relative, so it resolves against the process working directory.
model_name = "src/kr_api/kalium_recommend"

# Run on the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and classifier are loaded once, at import time.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name).to(device)

# Router on which this module's endpoints are registered.
router = APIRouter()
def _build_model_text(request: Request) -> str:
    """Render the request fields into the single text string fed to the model."""
    # NOTE(review): only the category segment ends with a space, so the other
    # segments run straight into each other in the final string. Kept
    # byte-for-byte because the checkpoint was presumably fine-tuned on this
    # exact format -- confirm against the training script before changing it.
    return (
        f"This ads is focused in: {request.audience}"
        f"This category is: {request.category} "
        f"This area is: {request.area}"
        f"This sub area of ads: {request.sub_area}"
    )


def _predict_logits(text: str) -> list:
    """Tokenize ``text``, run the classifier and return its logits as nested lists.

    This is blocking, compute-bound work (tokenization plus a forward pass),
    so it must not be called directly from the event loop.
    """
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=255,
    )
    # Inputs have to live on the same device the model was moved to.
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    return outputs.logits.tolist()


@router.post("/ai")
async def ai(request: Request):
    """Score an ad description with the fine-tuned BERT classifier.

    The ``audience``, ``category``, ``area`` and ``sub_area`` fields of
    ``request`` are rendered into one text string, which is tokenized and
    passed through the model.

    Returns:
        A dict ``{"similarity": logits}`` where ``logits`` is the model's raw
        (un-normalized) output converted to nested Python lists.
    """
    # Imported here so the endpoint does not rely on a file-level import
    # that the rest of the module may not have.
    from fastapi.concurrency import run_in_threadpool

    text = _build_model_text(request)

    # The forward pass is synchronous and compute-bound; running it inline in
    # an ``async def`` handler would block the event loop and stall every
    # other in-flight request, so it is handed off to the worker thread pool.
    similarity = await run_in_threadpool(_predict_logits, text)
    return {"similarity": similarity}