|
import json
import logging
import re
import time

import torch
from fastapi import FastAPI, HTTPException, Query
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
|
|
|
|
|
# Process-wide logging setup: INFO and above, each record rendered as
# "<timestamp> [<LEVEL>] <message>".
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Named logger shared by the rest of this module.
log = logging.getLogger("news-filter")
|
|
|
|
|
# Identifier handed to ``from_pretrained`` for both the tokenizer and the model
# (presumably a Hugging Face Hub repo id — confirm).
model_name = "habulaj/filter"
log.info("🚀 Iniciando carregamento do modelo e tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
log.info("✅ Tokenizer carregado.")

# Load the PEFT causal LM on CPU in float32.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()
log.info("✅ Modelo carregado e em modo eval.")

# NOTE: the model is deliberately NOT passed through ``torch.compile``; it runs
# in eager mode. A try/except used to sit here that contained no compile call,
# so it could never fail and always logged that the model had been compiled.
# It was removed so the startup log no longer reports a step that never runs.
|
|
|
|
|
# Application object; the route decorators in this module register their
# handlers on it.
app = FastAPI(
    title="News Filter JSON API",
)
|
|
|
@app.get("/")
def read_root():
    """Return a fixed status payload for the root path.

    The payload carries a human-readable "running" message and the ``/docs``
    path.
    """
    status = {
        "message": "News Filter JSON API is running!",
        "docs": "/docs",
    }
    return status
|
|
|
|
|
def _extract_json_object(text):
    """Return the first substring of *text* that decodes as a JSON object.

    Each ``{`` is tried in order as a candidate start. Returns ``None`` when
    no candidate decodes successfully.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        candidate = text[start:]
        try:
            # raw_decode stops at the end of the first complete document and
            # ignores whatever the model kept generating after it.
            _, end = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return candidate[:end]
    return None


def infer_filter(title: str, content: str) -> str:
    """Generate the JSON filter for one news item and return it as text.

    Builds the instruction prompt from *title* and *content*, samples up to
    100 new tokens from the module-level ``model`` and pulls a JSON object out
    of the generated text.

    Args:
        title: News headline; interpolated into the prompt and logged.
        content: News body; interpolated into the prompt. Only its first
            100 characters are logged.

    Returns:
        The first substring of the generation that decodes as a JSON object.
        If none decodes, the widest ``{...}`` span in the generation (which
        may not be valid JSON). If the generation contains no such span, a
        message starting with "⚠️ Failed to extract JSON." followed by the
        raw generation.
    """
    # NOTE(review): the wording and blank lines presumably mirror the prompt
    # used when the model was fine-tuned — confirm before editing this text.
    prompt = (
        "Analyze the news title and content, and return the filters in JSON "
        "format with the defined fields.\n"
        "\n"
        "Please respond ONLY with the JSON filter, do NOT add any "
        "explanations, system messages, or extra text.\n"
        "\n"
        f'Title: "{title}"\n'
        f'Content: "{content}"\n'
    )

    log.info(f"🧠 Iniciando inferência para notícia:\n📰 Title: {title}\n📝 Content: {content[:100]}...")
    start_time = time.time()

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
    )
    input_ids = inputs.input_ids.to("cpu")
    # Pass the mask explicitly: pad_token_id is set to the EOS id below, so
    # generate() should not be left to infer the mask from pad tokens.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to("cpu")

    with torch.no_grad():
        # early_stopping is omitted: it is a beam-search setting and this call
        # samples with a single beam.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            temperature=1.0,
            do_sample=True,
            top_p=0.9,
            num_beams=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count rather than by len(prompt) characters: the prompt may have been
    # truncated to 512 tokens and decoding is not guaranteed to reproduce the
    # prompt text exactly, so a character offset can cut in the wrong place.
    prompt_len = input_ids.shape[-1]
    generated = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()

    log.info("📤 Resposta bruta decodificada:")
    log.info(generated)

    json_result = _extract_json_object(generated)
    if json_result is None:
        # Nothing decodes cleanly: fall back to the widest brace-delimited
        # span so the caller still receives the model's attempt.
        match = re.search(r"\{.*\}", generated, re.DOTALL)
        if match:
            json_result = match.group(0)

    if json_result is None:
        log.warning("⚠️ Não foi possível extrair JSON.")
        return "⚠️ Failed to extract JSON. Output:\n" + generated

    duration = time.time() - start_time
    log.info(f"✅ JSON extraído com sucesso em {duration:.2f}s")
    return json_result
|
|
|
|
|
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="Title of the news"),
    content: str = Query(..., description="Content of the news"),
):
    """Run the filter model on one news item and return its JSON result.

    Args:
        title: Required query parameter with the news title.
        content: Required query parameter with the news content.

    Returns:
        The parsed JSON produced by ``infer_filter``. If that text is not
        valid JSON, ``{"raw_output": <text>}`` is returned instead.

    Raises:
        HTTPException: status 422 with the error text as detail when any
            other exception occurs during inference or parsing.
            NOTE(review): 422 is kept for backward compatibility; a 5xx
            status may be more appropriate for a server-side failure.
    """
    try:
        json_output = infer_filter(title, content)
        return json.loads(json_output)
    except json.JSONDecodeError:
        # json_output is always bound here: only json.loads raises this.
        log.error("❌ Erro ao fazer parse do JSON retornado.")
        return {"raw_output": json_output}
    except Exception as e:
        log.exception("❌ Erro inesperado durante a inferência:")
        raise HTTPException(status_code=422, detail=str(e)) from e