# filter / app.py — Hugging Face Space "habulaj/filter"
# NOTE: the following metadata is residue from the HF file viewer, kept as a comment:
#   habulaj's picture | Update app.py | commit 3988dcb (verified) | raw / history / blame | 3.39 kB
import json
import logging
import re
import time

import torch
from fastapi import FastAPI, Query, HTTPException
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
# -------- LOGGING CONFIG --------
# Root-logger config with timestamps; all app messages go through "news-filter".
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("news-filter")
# -------- LOAD MODEL --------
# PEFT adapter checkpoint loaded from the Hugging Face Hub at import time.
model_name = "habulaj/filter"
log.info("🚀 Carregando modelo e tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# CPU-only, full float32; low_cpu_mem_usage streams weights to reduce peak RAM.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
model.eval()  # inference only — disables dropout etc.
log.info("✅ Modelo carregado (eval mode).")
# Best-effort compilation: torch.compile may be unavailable on this
# platform/PyTorch build, so failure is logged and the eager model is kept.
try:
    model = torch.compile(model, mode="reduce-overhead")
    log.info("✅ Modelo compilado com torch.compile.")
except Exception as e:
    log.warning(f"⚠️ torch.compile não disponível: {e}")
# -------- FASTAPI INIT --------
app = FastAPI(title="News Filter JSON API")
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the service is up and points at the docs."""
    return {
        "message": "News Filter JSON API is running!",
        "docs": "/docs",
    }
# -------- INFERENCE --------
def infer_filter(title, content):
    """Run the filter model on one news item and extract its JSON answer.

    Builds a fixed instruction prompt around *title* and *content*, generates
    up to 100 new tokens greedily on CPU, strips the echoed prompt, and pulls
    the first ``{...}`` span out of the generated text.

    Args:
        title: News headline (interpolated verbatim into the prompt).
        content: News body text (truncated with the prompt to 512 tokens).

    Returns:
        str: the matched JSON substring, or a ``"⚠️ Failed to extract JSON"``
        message containing the raw generation when no ``{...}`` span is found.
        (Callers must still json-parse the result; this returns a string.)
    """
    prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
Title: "{title}"
Content: "{content}"
"""
    log.info(f"🧠 Inferência iniciada para: {title}")
    start_time = time.time()
    # Tokenization + attention mask; long inputs are truncated to 512 tokens.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    input_ids = inputs.input_ids.to("cpu")
    attention_mask = inputs.attention_mask.to("cpu")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            # NOTE(review): with do_sample=False, temperature/top_k should have
            # no effect on greedy decoding — confirm against transformers docs.
            temperature=1.0,
            do_sample=False,  # greedy decoding
            top_k=50,  # reasonable cap (inert under greedy decoding)
            no_repeat_ngram_size=2,
            num_beams=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Assumes decode() reproduces the prompt verbatim at the start of the
    # output so character slicing removes it — TODO confirm for this tokenizer.
    generated = decoded[len(prompt):].strip()
    log.info("📤 Resultado gerado:")
    log.info(generated)
    # Greedy match from the first "{" to the last "}" across newlines.
    match = re.search(r"\{.*\}", generated, re.DOTALL)
    if match:
        duration = time.time() - start_time
        json_result = match.group(0)
        log.info(f"✅ JSON extraído em {duration:.2f}s")
        return json_result
    else:
        log.warning("⚠️ Falha ao extrair JSON.")
        return "⚠️ Failed to extract JSON. Output:\n" + generated
# -------- API --------
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="News title"),
    content: str = Query(..., description="News content")
):
    """Run the news filter on (title, content) and return the parsed result.

    Returns:
        dict: the model's JSON output parsed into a dict, or
        ``{"raw_output": <text>}`` when the model output is not valid JSON.

    Raises:
        HTTPException: 422 carrying the error message for any other failure.
    """
    # BUGFIX: the original did `import json` inside the try block, making
    # `json` a function-local name. If infer_filter() raised first, evaluating
    # `except json.JSONDecodeError` hit an UnboundLocalError that masked the
    # real error and skipped the intended 422 path. `json` is now imported at
    # module level.
    try:
        json_output = infer_filter(title, content)
        return json.loads(json_output)
    except json.JSONDecodeError:
        # Model produced non-JSON text (e.g. the extraction-failure message);
        # surface it raw instead of erroring.
        log.error("❌ Erro ao fazer parse do JSON.")
        return {"raw_output": json_output}
    except Exception as e:
        log.exception("❌ Erro inesperado:")
        raise HTTPException(status_code=422, detail=str(e))