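"""News Filter JSON API.

A FastAPI service that wraps a PEFT-tuned causal language model
(habulaj/filterinstruct180) and extracts news "filters" as JSON from a
title/content pair via the /filter endpoint.
"""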
from fastapi import FastAPI, Query, HTTPException
import torch
import re
import time
import logging
import os
from transformers import AutoTokenizer, GenerationConfig
from peft import AutoPeftModelForCausalLM
import gc
import json

# -------- OPTIMIZATION SETTINGS --------
# Pin thread counts low for a small CPU instance; tokenizer parallelism is
# disabled to avoid fork-related warnings from the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

# -------- LOGGING CONFIG --------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("news-filter")

# -------- LOAD MODEL --------
model_name = "habulaj/filterinstruct180"
log.info("🚀 Carregando modelo e tokenizer...")

# Left padding keeps prompt tokens adjacent to the generated continuation,
# which decoder-only models require when batching.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    padding_side="left"
)

# Llama-style tokenizers ship without a pad token; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# AutoPeftModelForCausalLM loads the base model plus the adapter weights in
# one call; bfloat16 keeps memory at roughly half of float32.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=True,
    trust_remote_code=True
)

model.eval()
log.info("✅ Modelo carregado (eval mode).")

# Greedy decoding. temperature is unused when do_sample=False and
# length_penalty only applies to beam search, so both are omitted to avoid
# transformers' "generation flags not used" warnings.
generation_config = GenerationConfig(
    max_new_tokens=128,
    do_sample=False,
    num_beams=1,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=2,
    repetition_penalty=1.1
)

# -------- FASTAPI INIT --------
app = FastAPI(title="News Filter JSON API")

@app.get("/")
def read_root():
    return {"message": "News Filter JSON API is running!", "docs": "/docs"}

# -------- INFERENCE --------
def infer_filter(title, content):
    log.info(f"🧠 Inference started for: {title}")
    start_time = time.time()

    chat_prompt = build_chat_prompt(title, content)

    inputs = tokenizer(
        chat_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
        add_special_tokens=False
    )

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # inference_mode() already disables autograd; stacking no_grad() on top is redundant.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            num_return_sequences=1,
            output_scores=False,
            return_dict_in_generate=False
        )

    generated_tokens = outputs[0][len(input_ids[0]):]
    generated = tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    log.info("📤 Resultado gerado:")
    log.info(generated)

    json_result = extract_json(generated)

    duration = time.time() - start_time
    log.info(f"✅ JSON extraído em {duration:.2f}s")

    # Memory cleanup
    del outputs, generated_tokens, inputs
    gc.collect()

    if json_result:
        return json_result
    # 422 (Unprocessable Entity) describes a parse failure more accurately
    # than 404, which would suggest a missing resource.
    raise HTTPException(status_code=422, detail="Unable to extract JSON from model output.")

def build_chat_prompt(title: str, content: str) -> str:
    # Hand-rolled Llama 3-style chat template, matching the special tokens below.
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Analyze the news title and content, and return the filters in JSON format with the defined fields.

Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.

Title: "{title}"
Content: "{content}"<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

def extract_json(text):
    # Grab the first top-level {...} block, tolerating one level of nesting.
    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
    if match:
        json_text = match.group(0)

        # Common fix-ups: single quotes, Python booleans, trailing commas.
        # Note: the quote substitution is naive and will also rewrite
        # apostrophes inside string values.
        json_text = re.sub(r"'", '"', json_text)
        json_text = re.sub(r'\bTrue\b', 'true', json_text)
        json_text = re.sub(r'\bFalse\b', 'false', json_text)
        json_text = re.sub(r",\s*}", "}", json_text)
        json_text = re.sub(r",\s*]", "]", json_text)
        return json_text.strip()
    return None  # Let the caller raise instead of passing raw text through.
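
# Quick sanity check of the normalization above (hypothetical input, not real
# model output): extract_json("noise {'breaking': True,} noise") returns the
# cleaned string '{"breaking": true}'.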

# -------- API ROUTE --------
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="News title"),
    content: str = Query(..., description="News content")
):
    try:
        json_output = infer_filter(title, content)
        try:
            parsed = json.loads(json_output)
            return {"result": parsed}
        except json.JSONDecodeError as e:
            log.error(f"❌ Error parsing JSON: {e}")
            return {"result": json_output, "warning": "JSON returned as string due to parsing error"}
    except HTTPException as e:
        raise e
    except Exception as e:
        log.exception("❌ Erro inesperado:")
        raise HTTPException(status_code=500, detail="Internal server error during inference.")

# on_event("startup") is deprecated in recent FastAPI releases in favor of a
# lifespan handler, but this form still works.
@app.on_event("startup")
async def warmup():
    log.info("🔥 Running warmup...")
    try:
        infer_filter("Test title", "Test content")
        log.info("✅ Warmup complete.")
    except Exception as e:
        log.warning(f"⚠️ Warmup failed: {e}")