File size: 3,602 Bytes
aa6da07
 
 
99ff4e1
 
aa6da07
99ff4e1
 
 
 
2c0cc11
 
99ff4e1
 
2c0cc11
 
 
99ff4e1
2c0cc11
 
 
 
 
 
 
 
 
 
 
aa6da07
 
 
 
 
 
 
99ff4e1
aa6da07
2c0cc11
99ff4e1
2c0cc11
 
 
 
 
aa6da07
 
 
 
 
 
 
 
2c0cc11
 
 
 
 
 
 
 
99ff4e1
 
 
2c0cc11
99ff4e1
 
2c0cc11
 
99ff4e1
 
2c0cc11
 
99ff4e1
2c0cc11
aa6da07
99ff4e1
 
 
 
 
 
 
 
 
2c0cc11
 
 
 
 
99ff4e1
 
aa6da07
 
 
 
 
 
 
 
99ff4e1
2c0cc11
 
 
 
 
 
aa6da07
99ff4e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from fastapi import FastAPI, Query, HTTPException
import torch
import re
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

# Carrega modelo e tokenizer da Hugging Face - LoRA fine-tuned
model_name = "habulaj/filter"
print("Carregando tokenizer e modelo (CPU)...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Otimizações de performance
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,  # float32 é mais rápido em CPU
    low_cpu_mem_usage=True,     # Reduz uso de memória
)
model.eval()

# Compilação do modelo para otimizar (PyTorch 2.0+)
try:
    model = torch.compile(model, mode="reduce-overhead")
    print("✅ Modelo compilado com torch.compile")
except Exception as e:
    print(f"⚠️ torch.compile não disponível: {e}")

# Cache para prompts similares
prompt_cache = {}

# -------- FASTAPI --------
app = FastAPI(title="News Filter JSON API")

# -------- ROOT ENDPOINT --------
@app.get("/")
def read_root():
    return {"message": "News Filter JSON API is running!", "docs": "/docs"}

# Função para inferência otimizada
def infer_filter(title, content):
    # Cache key simples
    cache_key = hash((title[:50], content[:100]))
    if cache_key in prompt_cache:
        return prompt_cache[cache_key]
    
    prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.

Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.

Title: "{title}"
Content: "{content}"
"""

    # Otimizações de tokenização
    inputs = tokenizer(
        prompt, 
        return_tensors="pt",
        truncation=True,
        max_length=512,  # Limita tamanho do input
        padding=False    # Não faz padding desnecessário
    )
    input_ids = inputs.input_ids.to("cpu")
    
    with torch.no_grad():
        # Configurações otimizadas para velocidade
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=100,      # Reduzido de 128 para 100
            temperature=1.0,         # Reduzido para ser mais determinístico
            do_sample=True,
            top_p=0.9,
            num_beams=1,            # Beam search = 1 (greedy) é mais rápido
            early_stopping=True,    # Para quando encontrar EOS
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Remove prompt do output
    generated = decoded[len(prompt):].strip()
    
    # Extrai JSON
    match = re.search(r"\{.*\}", generated, re.DOTALL)
    if match:
        result = match.group(0)
        # Cache o resultado (limitado a 100 entradas)
        if len(prompt_cache) < 100:
            prompt_cache[cache_key] = result
        return result
    else:
        return "⚠️ Failed to extract JSON. Output:\n" + generated

# -------- API ROUTE --------
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="Title of the news"),
    content: str = Query(..., description="Content of the news")
):
    try:
        json_output = infer_filter(title, content)
        import json
        # Retorna como dados brutos (parse do JSON)
        return json.loads(json_output)
    except json.JSONDecodeError:
        # Se não conseguir fazer parse, retorna como string
        return {"raw_output": json_output}
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e))