from fastapi import FastAPI, Query, HTTPException

import torch
import re
import time
import json
import logging
import os
import gc

from transformers import AutoTokenizer, GenerationConfig
from peft import AutoPeftModelForCausalLM

# Keep CPU usage predictable on small hosts
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("news-filter")

model_name = "habulaj/filter"

log.info("🚀 Loading model and tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    padding_side="left"
)

# Llama-style tokenizers often ship without a pad token; reuse EOS for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=True,
    trust_remote_code=True
)
model.eval()
log.info("✅ Model loaded (eval mode).")

# Greedy decoding (do_sample=False); temperature is omitted because it has
# no effect when sampling is disabled
generation_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=False,
    num_beams=1,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=2,
    repetition_penalty=1.1,
    length_penalty=1.0
)

# torch.compile can fail on some setups (older PyTorch, unsupported ops),
# so fall back to the eager model if compilation is unavailable
try:
    model = torch.compile(
        model,
        mode="reduce-overhead",
        fullgraph=True,
        dynamic=False
    )
    log.info("✅ Model compiled with torch.compile.")
except Exception as e:
    log.warning(f"⚠️ torch.compile not available: {e}")

app = FastAPI(title="News Filter JSON API")


@app.get("/")
def read_root():
    return {"message": "News Filter JSON API is running!", "docs": "/docs"}

def infer_filter(title, content):
    prompt = f"""Analyze the news and return a valid JSON object with double quotes for all keys and string values.
Title: "{title}"
Content: "{content}"

Return only valid JSON:"""

    log.info(f"🧠 Inference started for: {title}")
    start_time = time.time()

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=384,
        padding=False,
        add_special_tokens=True,
    )

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # inference_mode alone is enough; nesting it inside no_grad is redundant
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            num_return_sequences=1,
            output_scores=False,
            return_dict_in_generate=False,
        )

    # Decode only the newly generated tokens, skipping the prompt
    generated_tokens = outputs[0][input_ids.shape[1]:]
    generated = tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )
log.info("📤 Resultado gerado:") |
|
log.info(generated) |
|
|
|
|
|

    # Grab the first {...} block, tolerating one level of nested braces
    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', generated)
    if match:
        duration = time.time() - start_time
        json_result = match.group(0)
        json_result = fix_json_format(json_result)
        log.info(f"✅ JSON extracted in {duration:.2f}s")

        del outputs, generated_tokens, inputs
        gc.collect()

        return json_result
    else:
        log.warning("⚠️ Failed to extract JSON.")
        del outputs, generated_tokens, inputs
        gc.collect()
        # 500 rather than 404: a missing JSON block is a server-side failure,
        # not a missing resource
        raise HTTPException(status_code=500, detail="Unable to extract JSON from model output.")

def fix_json_format(json_str):
    """Fix common formatting problems in LLM-generated JSON."""
    # Collapse newlines into spaces
    json_str = re.sub(r'\n\s*', ' ', json_str)

    # Convert single-quoted keys and string values to double quotes
    json_str = re.sub(r"'([^']*)':", r'"\1":', json_str)
    json_str = re.sub(r":\s*'([^']*)'", r': "\1"', json_str)

    # Python-style booleans -> JSON booleans
    json_str = re.sub(r':\s*True\b', ': true', json_str)
    json_str = re.sub(r':\s*False\b', ': false', json_str)

    # Drop trailing commas before closing braces/brackets
    json_str = re.sub(r',\s*}', '}', json_str)
    json_str = re.sub(r',\s*]', ']', json_str)

    # Normalize whitespace
    json_str = re.sub(r'\s+', ' ', json_str)

    return json_str.strip()
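
# A quick illustration with a hypothetical model output:
#   fix_json_format("{'breaking': True, 'category': 'politics',}")
#   -> '{"breaking": true, "category": "politics"}'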


@app.get("/filter")
def get_filter(
    title: str = Query(..., description="News title"),
    content: str = Query(..., description="News content")
):
    try:
        json_output = infer_filter(title, content)

        try:
            parsed_result = json.loads(json_output)
            return {"result": parsed_result}
        except json.JSONDecodeError as je:
            log.error(f"❌ Failed to parse JSON: {je}")
            log.error(f"Problematic JSON: {json_output}")
            # Fall back to the raw string so the client still gets the output
            return {"result": json_output, "warning": "JSON returned as string due to parsing error"}

    except HTTPException:
        # Re-raise FastAPI errors untouched; bare raise keeps the traceback
        raise
    except Exception:
        log.exception("❌ Unexpected error:")
        raise HTTPException(status_code=500, detail="Internal server error during inference.")
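
# Example request (hypothetical values) once the server is running:
#   GET /filter?title=Example%20headline&content=Example%20body%20text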


@app.on_event("startup")
async def warmup():
    """Run one inference at startup so the first real request is not slow."""
    log.info("🔥 Running warmup...")
    try:
        infer_filter("Test title", "Test content")
        log.info("✅ Warmup finished.")
    except Exception as e:
        log.warning(f"⚠️ Warmup failed: {e}")