File size: 3,386 Bytes
aa6da07
 
 
a15c41a
 
99ff4e1
 
aa6da07
a15c41a
3988dcb
a15c41a
 
3988dcb
2560f7d
3988dcb
a15c41a
99ff4e1
2c0cc11
99ff4e1
 
2c0cc11
a15c41a
 
99ff4e1
2c0cc11
3988dcb
2c0cc11
 
3988dcb
a15c41a
2c0cc11
3988dcb
2c0cc11
3988dcb
aa6da07
 
 
 
99ff4e1
aa6da07
3988dcb
99ff4e1
aa6da07
 
 
 
 
 
 
 
3988dcb
a15c41a
 
3988dcb
2c0cc11
a15c41a
2c0cc11
 
a15c41a
3988dcb
2c0cc11
99ff4e1
3988dcb
a15c41a
99ff4e1
 
 
3988dcb
a15c41a
 
3988dcb
 
 
a15c41a
99ff4e1
2c0cc11
aa6da07
a15c41a
99ff4e1
 
a15c41a
3988dcb
a15c41a
 
99ff4e1
 
a15c41a
3988dcb
 
a15c41a
99ff4e1
3988dcb
99ff4e1
aa6da07
3988dcb
aa6da07
 
3988dcb
 
aa6da07
 
99ff4e1
2560f7d
2c0cc11
 
3988dcb
2c0cc11
aa6da07
3988dcb
99ff4e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import logging
import re
import time

import torch
from fastapi import FastAPI, Query, HTTPException
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# -------- LOGGING CONFIG --------
# Root logging at INFO with timestamps; the service writes through the
# "news-filter" logger bound to `log`.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("news-filter")

# -------- LOAD MODEL --------
# Load the tokenizer and the PEFT causal-LM checkpoint once, at import time,
# so every request reuses the same in-memory objects.
model_name = "habulaj/filter"  # identifier passed to from_pretrained (presumably a HF Hub repo id — confirm)
log.info("🚀 Carregando modelo e tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# CPU-only inference in full float32; low_cpu_mem_usage asks the loader to
# keep peak RAM down while materializing weights.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
model.eval()  # inference mode for all subsequent generate() calls
log.info("✅ Modelo carregado (eval mode).")

# Best-effort compilation: any exception raised while wrapping is logged and
# the uncompiled model is kept.
# NOTE(review): torch.compile returns a wrapper whose attribute lookups fall
# through to the original module, so the `model.generate(...)` used by the
# inference code appears to run the *uncompiled* forward — this likely yields
# no speedup. Compilation is also lazy, so this try/except presumably only
# catches torch.compile being unavailable, not compile failures (those would
# surface on the first call). mode="reduce-overhead" targets CUDA graphs; its
# benefit on CPU is doubtful. Benchmark before relying on it.
try:
    model = torch.compile(model, mode="reduce-overhead")
    log.info("✅ Modelo compilado com torch.compile.")
except Exception as e:
    log.warning(f"⚠️ torch.compile não disponível: {e}")

# -------- FASTAPI INIT --------
# Application object; the route decorators that follow register on it.
app = FastAPI(
    title="News Filter JSON API",
)

@app.get("/")
def read_root():
    """Liveness probe: report that the service is up and where its docs live."""
    status = {"message": "News Filter JSON API is running!"}
    status["docs"] = "/docs"
    return status

# -------- INFERENCE --------
def infer_filter(title, content):
    """Run the filter model on one news item and extract its JSON answer.

    Args:
        title: News headline, interpolated verbatim into the prompt.
        content: News body, interpolated verbatim into the prompt.

    Returns:
        On success, the substring of the model's continuation spanning from
        the first "{" to the last "}" — a JSON candidate, not yet parsed.
        If no braces are found, the string
        "⚠️ Failed to extract JSON. Output:\\n" followed by the raw continuation
        (callers detect this by the parse failing).
    """
    prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.

Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.

Title: "{title}"
Content: "{content}"
"""

    log.info("🧠 Inferência iniciada para: %s", title)
    start_time = time.perf_counter()

    # Tokenize a single prompt (cut to at most 512 tokens). No padding: one
    # sequence needs none, and padding=True raises on tokenizers that lack a
    # pad token.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    input_ids = inputs.input_ids.to("cpu")
    attention_mask = inputs.attention_mask.to("cpu")
    prompt_token_count = input_ids.shape[-1]

    with torch.no_grad():
        # Plain greedy decoding. Sampling knobs (temperature/top_k) are
        # ignored when do_sample=False, so they are not passed.
        # no_repeat_ngram_size is deliberately NOT set: it bans every repeated
        # token bigram across prompt + output, which includes the `": "` /
        # `", "` pairs that any multi-field JSON object must repeat.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            do_sample=False,
            num_beams=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt tokens followed by new tokens; decode only the
    # new ones. Slicing the decoded text by len(prompt) is wrong whenever the
    # decoded prompt differs from the original string (e.g. after truncation).
    new_tokens = outputs[0][prompt_token_count:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    log.info("📤 Resultado gerado:")
    log.info("%s", generated)

    # Greedy match: from the first "{" to the last "}" across newlines.
    match = re.search(r"\{.*\}", generated, re.DOTALL)
    if match is None:
        log.warning("⚠️ Falha ao extrair JSON.")
        return "⚠️ Failed to extract JSON. Output:\n" + generated

    duration = time.perf_counter() - start_time
    log.info("✅ JSON extraído em %.2fs", duration)
    return match.group(0)

# -------- API --------
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="News title"),
    content: str = Query(..., description="News content")
):
    """Return the model's JSON filter for a news item.

    Args:
        title: News headline (required query parameter).
        content: News body (required query parameter).

    Returns:
        The parsed JSON value produced by ``infer_filter``. If that output is
        not valid JSON, ``{"raw_output": <unparsed text>}``.

    Raises:
        HTTPException: status 422 carrying the error text when inference
            itself fails.
    """
    # Inference failures map to 422. `json` is the module-level import: a
    # function-local `import json` inside the try would leave `json` unbound
    # in the except clause whenever infer_filter raised, turning the intended
    # 422 into an UnboundLocalError (500).
    try:
        json_output = infer_filter(title, content)
    except Exception as e:
        log.exception("❌ Erro inesperado:")
        raise HTTPException(status_code=422, detail=str(e)) from e

    # Unparseable model output (including the "Failed to extract" message)
    # is returned raw rather than treated as an error.
    try:
        return json.loads(json_output)
    except json.JSONDecodeError:
        log.error("❌ Erro ao fazer parse do JSON.")
        return {"raw_output": json_output}