habulaj committed
Commit 6658bef · verified · 1 Parent(s): bcf557b

Update app.py

Files changed (1):
  1. app.py +69 -78
app.py CHANGED
@@ -4,12 +4,14 @@ import re
 import time
 import logging
 import os
-from transformers import AutoTokenizer, GenerationConfig
-from peft import AutoPeftModelForCausalLM
 import gc
 import json
+from transformers import AutoTokenizer, GenerationConfig
+from peft import AutoPeftModelForCausalLM
+from unsloth.chat_templates import get_chat_template
+from unsloth import FastLanguageModel
 
-# -------- CONFIGS DE OTIMIZAÇÃO --------
+# -------- CONFIGURAÇÕES DE OTIMIZAÇÃO --------
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["OMP_NUM_THREADS"] = "2"
 os.environ["MKL_NUM_THREADS"] = "2"
@@ -20,43 +22,37 @@ torch.set_num_interop_threads(1)
 logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 log = logging.getLogger("news-filter")
 
-# -------- MODELO E TOKENIZER --------
-model_name = "habulaj/filterinstruct"
+# -------- MODELO --------
+model_name = "habulaj/filterinstruct3b"
 log.info("🚀 Carregando modelo e tokenizer...")
 
-tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, padding_side="left")
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
+
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
 
-# Aplica chat template
-def get_chat_template(tokenizer, chat_template="llama-3.1"):
-    tokenizer.chat_template = chat_template
-    return tokenizer
-
-tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
-
 model = AutoPeftModelForCausalLM.from_pretrained(
     model_name,
     device_map="cpu",
     torch_dtype=torch.bfloat16,
     low_cpu_mem_usage=True,
-    use_cache=True,
     trust_remote_code=True
 )
+FastLanguageModel.for_inference(model)
 model.eval()
-log.info("✅ Modelo carregado em modo eval.")
+log.info("✅ Modelo carregado (modo eval).")
 
-# -------- CONFIG DE GERAÇÃO --------
 generation_config = GenerationConfig(
     max_new_tokens=128,
-    temperature=1.2,
-    top_p=0.95,
-    do_sample=True,
+    temperature=1.0,
+    do_sample=False,
+    num_beams=1,
     use_cache=True,
     eos_token_id=tokenizer.eos_token_id,
     pad_token_id=tokenizer.eos_token_id,
     repetition_penalty=1.1,
-    no_repeat_ngram_size=2,
+    length_penalty=1.0
 )
 
 # -------- FASTAPI --------
@@ -66,12 +62,34 @@ app = FastAPI(title="News Filter JSON API")
 def read_root():
     return {"message": "News Filter JSON API is running!", "docs": "/docs"}
 
-# -------- INFERÊNCIA COM TEMPLATE --------
-def infer_filter(title, content):
-    log.info(f"🧠 Iniciando inferência para: {title}")
-    start_time = time.time()
+@app.get("/filter")
+def get_filter(
+    title: str = Query(..., description="News title"),
+    content: str = Query(..., description="News content")
+):
+    try:
+        result = infer_filter(title, content)
+        try:
+            return {"result": json.loads(result)}
+        except json.JSONDecodeError:
+            return {"result": result, "warning": "Returned as string due to JSON parsing error"}
+    except HTTPException as he:
+        raise he
+    except Exception as e:
+        log.exception("❌ Erro inesperado:")
+        raise HTTPException(status_code=500, detail="Internal server error during inference.")
+
+@app.on_event("startup")
+async def warmup():
+    log.info("🔥 Executando warmup...")
+    try:
+        infer_filter("Test title", "Test content")
+        log.info("✅ Warmup concluído.")
+    except Exception as e:
+        log.warning(f"⚠️ Warmup falhou: {e}")
 
-    # Histórico com exemplos
+# -------- INFERÊNCIA --------
+def infer_filter(title, content):
     messages = [
         {
             "role": "user",
@@ -85,7 +103,21 @@ Content: "Lucasfilm confirmed a new Star Wars movie set to release in 2026, dire
         },
         {
             "role": "assistant",
-            "content": '{ "death_related": false, "relevance": "high", "global_interest": true, "entity_type": "movie", "entity_name": "Star Wars", "is_promotional": false, "potential_for_viral": true, "urgency_level": "medium", "has_video_content": false }'
+            "content": '{ "death_related": false, "relevance": "high", "global_interest": true, "entity_type": "movie", "entity_name": "Star Wars", "breaking_news": true, "has_video_content": false }'
+        },
+        {
+            "role": "user",
+            "content": """Analyze the news title and content, and return the filters in JSON format with the defined fields.
+
+Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
+
+Title: "Legendary Musician Carlos Mendes Dies at 78"
+Content: "Carlos Mendes, the internationally acclaimed Brazilian guitarist and composer known for blending traditional bossa nova with modern jazz, has died at the age of 78."
+"""
+        },
+        {
+            "role": "assistant",
+            "content": '{ "death_related": true, "relevance": "high", "global_interest": true, "entity_type": "person", "entity_name": "Carlos Mendes", "breaking_news": true, "has_video_content": false }'
         },
         {
             "role": "user",
@@ -99,74 +131,33 @@ Content: "{content}"
         }
     ]
 
-    # Tokenização com chat template
+    log.info(f"🧠 Inferência iniciada para: {title}")
+    start_time = time.time()
+
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize=True,
        add_generation_prompt=True,
-        return_tensors="pt"
+        return_tensors="pt",
    ).to("cpu")
 
     with torch.no_grad(), torch.inference_mode():
         outputs = model.generate(
             input_ids=inputs,
             generation_config=generation_config,
-            return_dict_in_generate=False
         )
 
-    # Remove o prompt da saída
     prompt_text = tokenizer.decode(inputs[0], skip_special_tokens=True)
-    decoded_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    generated_only = decoded_text[len(prompt_text):].strip()
-
-    json_result = extract_json(generated_only)
+    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    generated = full_output[len(prompt_text):].strip()
 
+    json_str = extract_json(generated)
     duration = time.time() - start_time
     log.info(f"✅ JSON extraído em {duration:.2f}s")
-    log.info(json_result)
-
-    # Limpeza
-    del outputs, inputs
-    gc.collect()
-
-    return json_result
+    return json_str
 
 def extract_json(text):
-    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
+    match = re.search(r'\{.*?\}', text, flags=re.DOTALL)
     if match:
-        json_str = match.group(0)
-        # Correções mínimas
-        json_str = re.sub(r"'\s*:\s*'([^']*)'", r'": "\1"', json_str)
-        json_str = re.sub(r"'", '"', json_str)
-        json_str = re.sub(r"\bTrue\b", "true", json_str)
-        json_str = re.sub(r"\bFalse\b", "false", json_str)
-        return json_str.strip()
-    return text.strip()
-
-# -------- ENDPOINT --------
-@app.get("/filter")
-def get_filter(
-    title: str = Query(..., description="News title"),
-    content: str = Query(..., description="News content")
-):
-    try:
-        json_output = infer_filter(title, content)
-        try:
-            parsed = json.loads(json_output)
-            return {"result": parsed}
-        except json.JSONDecodeError:
-            log.error("❌ JSON inválido ao fazer parse.")
-            return {"result": json_output, "warning": "Raw JSON string returned due to parse error"}
-    except Exception as e:
-        log.exception("❌ Erro inesperado:")
-        raise HTTPException(status_code=500, detail="Erro interno durante a inferência.")
-
-# -------- WARMUP --------
-@app.on_event("startup")
-async def warmup():
-    log.info("🔥 Warmup iniciado...")
-    try:
-        infer_filter("Test title", "Test content")
-        log.info("✅ Warmup concluído.")
-    except Exception as e:
-        log.warning(f"⚠️ Warmup falhou: {e}")
+        return match.group(0)
+    return text
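
A note on the simplified extract_json: the removed pattern tolerated one level of nested braces, while the new non-greedy \{.*?\} stops at the first closing brace. For the flat, single-level JSON the few-shot examples elicit the two behave the same, but a nested object would be truncated. A small illustrative sketch (the sample strings are invented for this comparison, not taken from the commit):

import re

flat = 'noise { "relevance": "high", "breaking_news": true } trailing'
nested = 'noise { "relevance": "high", "meta": { "source": "wire" } } trailing'

old_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'  # pattern removed by this commit
new_pattern = r'\{.*?\}'                          # pattern added by this commit

print(re.search(old_pattern, flat, re.DOTALL).group(0))    # whole flat object
print(re.search(new_pattern, flat, re.DOTALL).group(0))    # whole flat object (same result)
print(re.search(old_pattern, nested, re.DOTALL).group(0))  # whole nested object
print(re.search(new_pattern, nested, re.DOTALL).group(0))  # cut off at the first '}', i.e. the inner object's close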
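
For quick manual testing, a minimal sketch of how the /filter endpoint could be called once the app is running; the uvicorn command, host and port are assumptions for illustration and are not part of the commit:

#   uvicorn app:app --host 0.0.0.0 --port 7860   (assumed way of serving the app)
import requests  # used only for this example

resp = requests.get(
    "http://localhost:7860/filter",  # assumed local URL
    params={
        "title": "New Star Wars Movie Announced",
        "content": "Lucasfilm confirmed a new Star Wars movie set to release in 2026.",
    },
)
print(resp.json())  # {"result": {...}} on success; a "warning" key is added when the model output is not valid JSON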
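
Alternatively, the endpoint and the startup warmup can be exercised in-process with FastAPI's TestClient (a test-only sketch, assuming the module is saved as app.py; the model must still be downloadable for this to run):

from fastapi.testclient import TestClient
from app import app  # assumes this file is importable as app.py

# Entering the client as a context manager also fires the @app.on_event("startup") warmup.
with TestClient(app) as client:
    r = client.get("/filter", params={"title": "Test title", "content": "Test content"})
    print(r.status_code, r.json())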