filter / app.py
habulaj's picture
Update app.py
480aadb verified
raw
history blame
3.36 kB
import json
import logging
import re
import time

import torch
from fastapi import FastAPI, Query, HTTPException
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
# -------- LOGGING CONFIG --------
# Configure the root logger once at import time: INFO and above, each record
# rendered as "<timestamp> [<LEVEL>] <message>".
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Named logger used throughout this module.
log = logging.getLogger("news-filter")
# -------- MODEL LOADING --------
# Identifier handed to both `from_pretrained` calls below (tokenizer and model).
model_name = "habulaj/filter"

log.info("🚀 Iniciando carregamento do modelo e tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
log.info("✅ Tokenizer carregado.")

# The model is loaded on CPU in float32; `low_cpu_mem_usage=True` is forwarded
# to the loader unchanged.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()
log.info("✅ Modelo carregado e em modo eval.")

# NOTE: the model is intentionally NOT wrapped with torch.compile here. An
# earlier revision logged "compiled with torch.compile" without ever calling
# it; that misleading success message (and its unreachable except branch) was
# removed so the startup log reflects what actually happened.
# -------- FASTAPI --------
# Application object that the route decorators in this module register on.
app = FastAPI(
    title="News Filter JSON API",
)
@app.get("/")
def read_root():
    """Return a small status payload that also points at the docs path."""
    payload = {
        "message": "News Filter JSON API is running!",
        "docs": "/docs",
    }
    return payload
# -------- INFERÊNCIA --------
def _extract_first_json_object(text):
    """Return the first substring of *text* that parses as a JSON object, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            parsed, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a valid document at this brace; try the next candidate.
            pass
        else:
            if isinstance(parsed, dict):
                return text[start:end]
        start = text.find("{", start + 1)
    return None


def infer_filter(title, content):
    """Generate the JSON filter for one news item and return it as a string.

    Args:
        title: News title, interpolated verbatim into the prompt.
        content: News body, interpolated verbatim into the prompt. The
            tokenized prompt is truncated to 512 tokens.

    Returns:
        The JSON text extracted from the model output. If no brace-delimited
        span can be found at all, a human-readable failure message that
        starts with "⚠️ Failed to extract JSON." followed by the raw output.
        The returned string is not guaranteed to be valid JSON; callers are
        expected to parse it and handle failures.
    """
    prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
Title: "{title}"
Content: "{content}"
"""
    log.info(f"🧠 Iniciando inferência para notícia:\n📰 Title: {title}\n📝 Content: {content[:100]}...")
    start_time = time.time()

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
    )
    input_ids = inputs.input_ids.to("cpu")
    # Passed explicitly because pad_token_id is set to eos_token_id below, so
    # the mask cannot be inferred reliably from the token ids alone.
    attention_mask = inputs.attention_mask.to("cpu")

    with torch.no_grad():
        # `early_stopping` is deliberately not set: it only applies to beam
        # search and merely triggers a warning when num_beams == 1.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            temperature=1.0,
            do_sample=True,
            top_p=0.9,
            num_beams=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # len(prompt) is wrong whenever the prompt was truncated or the tokenizer
    # does not reproduce the prompt text character for character.
    prompt_token_count = input_ids.shape[-1]
    new_tokens = outputs[0][prompt_token_count:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    log.info("📤 Resposta bruta decodificada:")
    log.info(generated)

    # Prefer the first span that is a well-formed JSON object; this tolerates
    # trailing text or additional objects after the answer.
    json_result = _extract_first_json_object(generated)
    if json_result is None:
        # Fall back to the widest {...} span so the caller still receives the
        # model's attempted JSON (and can report it as raw output).
        match = re.search(r"\{.*\}", generated, re.DOTALL)
        if match:
            json_result = match.group(0)

    if json_result is None:
        log.warning("⚠️ Não foi possível extrair JSON.")
        return "⚠️ Failed to extract JSON. Output:\n" + generated

    duration = time.time() - start_time
    log.info(f"✅ JSON extraído com sucesso em {duration:.2f}s")
    return json_result
# -------- ENDPOINT --------
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="Title of the news"),
    content: str = Query(..., description="Content of the news"),
):
    """Run the filter model on one news item and return its JSON result.

    Args:
        title: News title (required query parameter).
        content: News content (required query parameter).

    Returns:
        The parsed JSON produced by the model, or ``{"raw_output": <text>}``
        when the model output could not be parsed as JSON.

    Raises:
        HTTPException: With status 500 when inference itself fails. The
            previous code intended 422 here, but its function-local
            ``import json`` made the ``except json.JSONDecodeError`` clause
            raise ``UnboundLocalError`` whenever inference failed, so clients
            already received a 500; 500 is also the appropriate status for a
            server-side failure.
    """
    # `json` is imported at module level so that evaluating the except clause
    # below can never fail, regardless of where an error occurs.
    try:
        json_output = infer_filter(title, content)
    except Exception as e:
        log.exception("❌ Erro inesperado durante a inferência:")
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        return json.loads(json_output)
    except json.JSONDecodeError:
        # The model answered, but not with valid JSON: hand the raw text back
        # instead of failing the request.
        log.error("❌ Erro ao fazer parse do JSON retornado.")
        return {"raw_output": json_output}