|
from fastapi import FastAPI, Query, HTTPException |
|
import torch |
|
import re |
|
import time |
|
import logging |
|
import os |
|
from transformers import AutoTokenizer, GenerationConfig |
|
from peft import AutoPeftModelForCausalLM |
|
import gc |
|
|
|
|
|
# Pin CPU inference to a small, fixed thread budget.
# NOTE(review): OMP_NUM_THREADS / MKL_NUM_THREADS are usually read when the
# native libraries initialise, which may already have happened at `import torch`;
# torch.set_num_threads() below is what reliably caps torch's intra-op threads.
# TOKENIZERS_PARALLELISM=false turns off the HF tokenizers library's own parallelism.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "OMP_NUM_THREADS": "2",
        "MKL_NUM_THREADS": "2",
    }
)

torch.set_num_threads(2)  # intra-op parallelism
torch.set_num_interop_threads(1)  # inter-op parallelism; must be set before any parallel work starts
|
|
|
|
|
# Root logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("news-filter")
|
|
|
|
|
model_name = "habulaj/filterinstruct180"

log.info("🚀 Carregando modelo e tokenizer...")

# Tokenizer from the same repository as the model; padding (if any) goes on the left.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, padding_side="left")

# Fall back to EOS when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the adapter checkpoint through PEFT's auto class, pinned to CPU in bfloat16.
# NOTE(review): trust_remote_code=True executes Python shipped with the model
# repository; pinning a `revision=` would make that reproducible and auditable.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=True,
)

# Inference only: disable training-time behaviour such as dropout.
model.eval()
log.info("✅ Modelo carregado (eval mode).")
|
|
|
generation_config = GenerationConfig( |
|
max_new_tokens=128, |
|
temperature=1.0, |
|
do_sample=False, |
|
num_beams=1, |
|
use_cache=True, |
|
eos_token_id=tokenizer.eos_token_id, |
|
pad_token_id=tokenizer.eos_token_id, |
|
no_repeat_ngram_size=2, |
|
repetition_penalty=1.1, |
|
length_penalty=1.0 |
|
) |
|
|
|
|
|
app = FastAPI(title="News Filter JSON API")


@app.get("/")
def read_root():
    """Landing endpoint: report that the service is up and where the API docs live."""
    payload = {"message": "News Filter JSON API is running!", "docs": "/docs"}
    return payload
|
|
|
|
|
def _encode_prompt(title, content, max_length=512):
    """Tokenize the chat prompt, keeping the assistant header intact when truncating.

    Plain tokenizer truncation cuts the assembled prompt from the end, which for
    long articles removes the trailing ``<|eot_id|>...assistant<|end_header_id|>``
    header, so the model is never asked to answer. Instead, when the prompt is
    over budget, the head is shortened and the untouched tail that
    ``build_chat_prompt`` emits after the content is re-attached.

    Args:
        title: News headline.
        content: News body.
        max_length: Token budget for the whole prompt.

    Returns:
        ``(input_ids, attention_mask)`` tensors of shape ``(1, n)`` with
        ``n <= max_length`` (when the tail itself fits in the budget).
    """
    prompt = build_chat_prompt(title, content)
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids

    if input_ids.shape[1] > max_length:
        # Render the prompt with a marker as content to isolate everything that
        # follows the content (closing quote + end-of-turn + assistant header).
        # rsplit: the content slot comes after the title, so the last marker is it.
        marker = "\x00"
        tail = build_chat_prompt(title, marker).rsplit(marker, 1)[1]
        tail_ids = tokenizer(tail, return_tensors="pt", add_special_tokens=False).input_ids
        keep = max(max_length - tail_ids.shape[1], 0)
        input_ids = torch.cat([input_ids[:, :keep], tail_ids], dim=1)

    # Single unpadded sequence: every position is attended to.
    return input_ids, torch.ones_like(input_ids)


def infer_filter(title, content):
    """Run the model on one news item and return the extracted JSON text.

    Args:
        title: News headline, inserted into the chat prompt.
        content: News body; shortened if the prompt exceeds the 512-token budget.

    Returns:
        The string returned by ``extract_json`` for the generated text.

    Raises:
        HTTPException: status 404 when ``extract_json`` yields a falsy (e.g. empty) result.
    """
    log.info(f"🧠 Inferência iniciada para: {title}")
    start_time = time.perf_counter()

    input_ids, attention_mask = _encode_prompt(title, content)

    # inference_mode() already disables autograd; wrapping no_grad() too is redundant.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            num_return_sequences=1,
            output_scores=False,
            return_dict_in_generate=False,
        )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    generated_tokens = outputs[0][input_ids.shape[1]:]
    generated = tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    log.info("📤 Resultado gerado:")
    log.info(generated)

    json_result = extract_json(generated)

    duration = time.perf_counter() - start_time
    log.info(f"✅ JSON extraído em {duration:.2f}s")

    # Drop all tensor references before the explicit collection so their
    # memory can actually be reclaimed.
    del outputs, generated_tokens, input_ids, attention_mask
    gc.collect()

    if not json_result:
        raise HTTPException(status_code=404, detail="Unable to extract JSON from model output.")
    return json_result
|
|
|
def build_chat_prompt(title: str, content: str) -> str:
    """Render the single-turn chat prompt the fine-tuned model expects.

    The prompt is one user turn (instructions, then the quoted title and content),
    followed by an open assistant header so the model's continuation is the answer.

    NOTE(review): title and content are inserted verbatim. Quotes are not escaped,
    and neither is any ``<|...|>`` marker text a caller might send.
    """
    header = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n"
    instructions = (
        "Analyze the news title and content, and return the filters in JSON format "
        "with the defined fields.\n"
        "\n"
        "Please respond ONLY with the JSON filter, do NOT add any explanations, "
        "system messages, or extra text.\n"
        "\n"
    )
    article = f'Title: "{title}"\nContent: "{content}"'
    footer = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    return header + instructions + article + footer
|
|
|
def extract_json(text):
    """Pull the first JSON object out of free-form model output.

    Strategy:
      1. If the text starting at its first ``{`` decodes as valid JSON, return that
         object's exact source span. This works at any nesting depth and leaves
         string contents such as apostrophes or the word "True" untouched.
      2. Otherwise take the first ``{...}`` span with at most one level of nested
         braces. If that span is already valid JSON, return it unchanged. If not,
         repair Python-literal style output: single quotes become double quotes,
         ``True``/``False`` become ``true``/``false``, and trailing commas before
         ``}``/``]`` are dropped.

    Args:
        text: Decoded model output.

    Returns:
        The extracted (possibly repaired) JSON text, or ``text`` unchanged when no
        brace-delimited span is found. The result is not guaranteed to parse.
    """
    import json

    decoder = json.JSONDecoder()

    first_brace = text.find("{")
    if first_brace != -1:
        try:
            _, end = decoder.raw_decode(text, first_brace)
        except ValueError:
            pass  # not valid JSON from here; fall through to the regex + repair path
        else:
            return text[first_brace:end]

    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
    if not match:
        return text

    json_text = match.group(0)

    # Only repair what is actually broken: blind quote swapping would corrupt
    # valid JSON whose strings contain apostrophes.
    try:
        json.loads(json_text)
    except ValueError:
        pass
    else:
        return json_text

    json_text = json_text.replace("'", '"')
    json_text = re.sub(r'\bTrue\b', 'true', json_text)
    json_text = re.sub(r'\bFalse\b', 'false', json_text)
    json_text = re.sub(r",\s*}", "}", json_text)
    json_text = re.sub(r",\s*]", "]", json_text)
    return json_text.strip()
|
|
|
|
|
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="News title"),
    content: str = Query(..., description="News content"),
):
    """Classify a news item and return the model's JSON filter.

    Args:
        title: News headline (query parameter).
        content: News body (query parameter).

    Returns:
        ``{"result": <parsed object>}`` when the model output parses as JSON.
        Otherwise ``{"result": <raw string>, "warning": ...}``.

    Raises:
        HTTPException: status codes raised by ``infer_filter`` are passed through.
            Any other inference failure becomes a 500.

    NOTE(review): full article bodies in a GET query string can exceed proxy/server
    URL length limits; a POST body would avoid that but changes the API contract.
    """
    import json

    try:
        json_output = infer_filter(title, content)
    except HTTPException:
        # Already carries a deliberate status code; re-raise with its traceback intact.
        raise
    except Exception as exc:
        log.exception("❌ Erro inesperado:")
        raise HTTPException(status_code=500, detail="Internal server error during inference.") from exc

    try:
        parsed = json.loads(json_output)
    except json.JSONDecodeError as e:
        log.error(f"❌ Erro ao parsear JSON: {e}")
        return {"result": json_output, "warning": "JSON returned as string due to parsing error"}
    return {"result": parsed}
|
|
|
@app.on_event("startup")
async def warmup():
    """Run one throwaway inference on dummy input when the app starts.

    Presumably this absorbs first-call overhead before real traffic arrives.
    ``infer_filter`` is synchronous, so the event loop is blocked until it
    finishes. Failures are logged as warnings and never abort startup.
    """
    log.info("🔥 Executando warmup...")
    try:
        infer_filter("Test title", "Test content")
    except Exception as exc:
        log.warning(f"⚠️ Warmup falhou: {exc}")
    else:
        log.info("✅ Warmup concluído.")