KR_API / src /kr_api /routes /ai_routes.py
kauabarros-24
Ai api
598f7ca
from src.kr_api.models.ai_request import Request
import torch
from fastapi import APIRouter
from transformers import BertForSequenceClassification, BertTokenizer
# Directory holding the fine-tuned BERT classifier and its tokenizer files.
# NOTE(review): the path is relative, so it resolves against the process
# working directory — presumably the repository root; confirm for deployment.
model_name = "src/kr_api/kalium_recommend"

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Load the weights once at import time and place them on the chosen device;
# the tokenizer is loaded from the same directory.
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = BertTokenizer.from_pretrained(model_name)

# Router that collects the endpoints declared in this module.
router = APIRouter()
def _score_text(text: str) -> list:
    """Tokenize *text*, run the classifier once and return its logits.

    This is the blocking part of the ``/ai`` endpoint (tokenization plus a
    forward pass), kept synchronous so it can be executed in a worker thread.

    Args:
        text: The prompt string to classify.

    Returns:
        ``outputs.logits`` converted to plain nested Python lists.
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # for a single text it produces the same encoding.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=255,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Gradient tracking is not needed for inference. ``no_grad`` is
    # thread-local, so it has to be entered inside the worker thread.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    return outputs.logits.tolist()


@router.post("/ai")
async def ai(request: Request):
    """Score an ad description with the fine-tuned BERT classifier.

    Builds one prompt string from the ``audience``, ``category``, ``area``
    and ``sub_area`` fields of the request body, runs the model on it and
    returns the raw logits.

    Args:
        request: Request body exposing ``audience``, ``category``, ``area``
            and ``sub_area``.

    Returns:
        A dict ``{"similarity": <logits as nested lists>}``.
    """
    # Imported locally so this block does not depend on a new file-level import.
    from fastapi.concurrency import run_in_threadpool

    # NOTE(review): only the second fragment ends with a space, so the other
    # fragments run straight into the next sentence. The string is kept
    # byte-for-byte because the model was presumably fine-tuned on exactly
    # this format — confirm against the training data before changing it.
    text = (
        f"This ads is focused in: {request.audience}"
        f"This category is: {request.category} "
        f"This area is: {request.area}"
        f"This sub area of ads: {request.sub_area}"
    )

    # Tokenization and the forward pass are synchronous and compute-bound.
    # Running them directly in this coroutine would stall the event loop (and
    # every other in-flight request) for the whole inference, so hand the work
    # to the thread pool instead.
    # NOTE(review): concurrent requests can now run forward passes in
    # parallel threads — confirm memory/throughput are acceptable.
    logits = await run_in_threadpool(_score_text, text)
    return {"similarity": logits}