habulaj commited on
Commit
7440293
·
verified ·
1 Parent(s): 9112884

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -115
app.py CHANGED
@@ -1,128 +1,198 @@
1
  import gradio as gr
2
  from peft import AutoPeftModelForCausalLM
3
- from transformers import AutoTokenizer
4
  import torch
5
  import re
6
  import json
 
 
 
 
7
 
8
- # Configuração global para usar CPU
9
- device = "cpu"
10
- torch.set_default_device(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Carrega o modelo e tokenizer uma vez no início
13
- print("Carregando modelo e tokenizer...")
14
  model = AutoPeftModelForCausalLM.from_pretrained(
15
- "habulaj/filterinstruct180",
16
- device_map=device,
17
- torch_dtype=torch.float32, # Usa float32 para CPU
18
- load_in_4bit=False, # Desabilita quantização para CPU
 
 
19
  )
20
 
21
- tokenizer = AutoTokenizer.from_pretrained("habulaj/filterinstruct180")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Configura o chat template
24
- tokenizer.chat_template = """{% for message in messages %}
25
- {%- if message['role'] == 'user' %}
26
- {%- if loop.first %}
27
- <|begin_of_text|><|start_header_id|>user<|end_header_id|>
28
 
29
- {{ message['content'] }}<|eot_id|>
30
- {%- else %}
31
- <|start_header_id|>user<|end_header_id|>
32
 
33
- {{ message['content'] }}<|eot_id|>
34
- {%- endif %}
35
- {%- elif message['role'] == 'assistant' %}
36
- <|start_header_id|>assistant<|end_header_id|>
37
 
38
- {{ message['content'] }}<|eot_id|>
39
- {%- endif %}
40
- {%- endfor %}
41
- {%- if add_generation_prompt %}
42
- <|start_header_id|>assistant<|end_header_id|>
43
 
44
- {%- endif %}"""
45
 
46
  def extract_json(text):
47
- """Extrai apenas o JSON da resposta"""
48
- match = re.search(r'\{.*\}', text, flags=re.DOTALL)
49
  if match:
50
- return match.group(0)
 
 
 
 
 
 
 
 
51
  return text
52
 
53
- def analyze_news(title, content):
54
- """Função principal de análise de notícias"""
55
- try:
56
- # Prepara a mensagem
57
- messages = [
58
- {
59
- "role": "user",
60
- "content": f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
61
 
62
- Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
63
 
64
- Title: "{title}"
65
- Content: "{content}"
66
- """
67
- }
68
- ]
69
-
70
- # Aplica o template e tokeniza
71
- inputs = tokenizer.apply_chat_template(
72
- messages,
73
- tokenize=True,
74
- add_generation_prompt=True,
75
- return_tensors="pt",
 
 
 
 
 
 
 
 
76
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Gera a resposta
79
- with torch.no_grad():
80
- outputs = model.generate(
81
- input_ids=inputs,
82
- max_new_tokens=200,
83
- use_cache=True,
84
- temperature=1.0,
85
- min_p=0.1,
86
- pad_token_id=tokenizer.eos_token_id,
87
- do_sample=True,
88
- )
89
-
90
- # Decode input (prompt)
91
- prompt_text = tokenizer.decode(inputs[0], skip_special_tokens=False)
92
-
93
- # Decode output (prompt + resposta)
94
- decoded_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
95
-
96
- # Geração pura (remove o prompt)
97
- generated_only = decoded_text[len(prompt_text):].strip()
98
-
99
- # Extrai só o JSON
100
- json_result = extract_json(generated_only)
101
 
102
- # Tenta validar o JSON
103
- try:
104
- parsed_json = json.loads(json_result)
105
- return json.dumps(parsed_json, indent=2, ensure_ascii=False)
106
- except json.JSONDecodeError:
107
- return json_result
 
 
 
 
 
 
 
108
 
109
  except Exception as e:
110
- return f"Erro durante a análise: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # Interface Gradio
113
  def create_interface():
114
  with gr.Blocks(
115
- title="Analisador de Notícias",
116
  theme=gr.themes.Soft(),
117
  css="""
118
  .gradio-container {
119
  max-width: 1200px !important;
120
  }
 
 
 
 
 
 
121
  """
122
  ) as demo:
123
 
124
- gr.Markdown("# 📰 Analisador de Notícias")
125
- gr.Markdown("Insira o título e conteúdo da notícia para obter os filtros em formato JSON.")
126
 
127
  with gr.Row():
128
  with gr.Column(scale=1):
@@ -141,11 +211,12 @@ def create_interface():
141
  analyze_btn = gr.Button("🔍 Analisar Notícia", variant="primary")
142
 
143
  # Exemplos predefinidos
144
- gr.Markdown("### Exemplos:")
145
 
146
- example_btn1 = gr.Button("📻 Exemplo: Músico", size="sm")
147
- example_btn2 = gr.Button(" Exemplo: Esporte", size="sm")
148
- example_btn3 = gr.Button("💼 Exemplo: Negócios", size="sm")
 
149
 
150
  with gr.Column(scale=1):
151
  output = gr.Textbox(
@@ -155,23 +226,24 @@ def create_interface():
155
  show_copy_button=True
156
  )
157
 
158
- gr.Markdown("### Status:")
159
  status = gr.Textbox(
160
  label="Status da Análise",
161
- value="Aguardando entrada...",
162
  interactive=False
163
  )
164
-
165
- # Função para atualizar status
166
- def update_status_and_analyze(title, content):
167
- if not title.strip() or not content.strip():
168
- return "❌ Por favor, preencha tanto o título quanto o conteúdo.", "Erro: Campos obrigatórios não preenchidos."
169
-
170
- try:
171
- result = analyze_news(title, content)
172
- return f"✅ Análise concluída com sucesso!", result
173
- except Exception as e:
174
- return f"❌ Erro na análise: {str(e)}", f"Erro: {str(e)}"
 
 
175
 
176
  # Exemplos predefinidos
177
  def load_example_1():
@@ -194,7 +266,7 @@ def create_interface():
194
 
195
  # Event handlers
196
  analyze_btn.click(
197
- fn=update_status_and_analyze,
198
  inputs=[title_input, content_input],
199
  outputs=[status, output]
200
  )
@@ -215,24 +287,28 @@ def create_interface():
215
  )
216
 
217
  # Informações adicionais
218
- with gr.Accordion("ℹ️ Informações", open=False):
219
  gr.Markdown("""
220
- **Como usar:**
221
- 1. Insira o título da notícia
222
- 2. Insira o conteúdo da notícia
223
- 3. Clique em "Analisar Notícia"
224
- 4. O resultado será exibido em formato JSON
 
225
 
226
- **Notas:**
227
- - O modelo está rodando em CPU
228
- - O processamento pode levar alguns segundos
229
- - Use os exemplos predefinidos para testar rapidamente
230
  """)
231
 
232
  return demo
233
 
234
  if __name__ == "__main__":
235
- print("Iniciando interface Gradio...")
 
 
 
236
  demo = create_interface()
237
  demo.launch(
238
  share=False,
 
1
  import gradio as gr
2
  from peft import AutoPeftModelForCausalLM
3
+ from transformers import AutoTokenizer, GenerationConfig
4
  import torch
5
  import re
6
  import json
7
+ import time
8
+ import logging
9
+ import os
10
+ import gc
11
 
12
+ # -------- CONFIGURAÇÕES DE OTIMIZAÇÃO --------
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+ os.environ["OMP_NUM_THREADS"] = "2"
15
+ os.environ["MKL_NUM_THREADS"] = "2"
16
+ torch.set_num_threads(2)
17
+ torch.set_num_interop_threads(1)
18
+
19
+ # -------- LOGGING CONFIG --------
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
21
+ log = logging.getLogger("news-filter-gradio")
22
+
23
+ # -------- LOAD MODEL --------
24
+ model_name = "habulaj/filterinstruct180"
25
+ log.info("🚀 Carregando modelo e tokenizer...")
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(
28
+ model_name,
29
+ use_fast=True,
30
+ padding_side="left"
31
+ )
32
+
33
+ if tokenizer.pad_token is None:
34
+ tokenizer.pad_token = tokenizer.eos_token
35
 
 
 
36
  model = AutoPeftModelForCausalLM.from_pretrained(
37
+ model_name,
38
+ device_map="cpu",
39
+ torch_dtype=torch.bfloat16,
40
+ low_cpu_mem_usage=True,
41
+ use_cache=True,
42
+ trust_remote_code=True
43
  )
44
 
45
+ model.eval()
46
+ log.info("✅ Modelo carregado (eval mode).")
47
+
48
+ generation_config = GenerationConfig(
49
+ max_new_tokens=128,
50
+ temperature=1.0,
51
+ do_sample=False,
52
+ num_beams=1,
53
+ use_cache=True,
54
+ eos_token_id=tokenizer.eos_token_id,
55
+ pad_token_id=tokenizer.eos_token_id,
56
+ no_repeat_ngram_size=2,
57
+ repetition_penalty=1.1,
58
+ length_penalty=1.0
59
+ )
60
 
61
+ def build_chat_prompt(title, content):
62
+ """Constrói o prompt do chat"""
63
+ return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
 
 
64
 
65
+ Analyze the news title and content, and return the filters in JSON format with the defined fields.
 
 
66
 
67
+ Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
 
 
 
68
 
69
+ Title: "{title}"
70
+ Content: "{content}"
71
+ <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 
72
 
73
+ """
74
 
75
  def extract_json(text):
76
+ """Extrai e limpa o JSON da resposta"""
77
+ match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
78
  if match:
79
+ json_text = match.group(0)
80
+
81
+ # Conversões comuns
82
+ json_text = re.sub(r"'", '"', json_text)
83
+ json_text = re.sub(r'\bTrue\b', 'true', json_text)
84
+ json_text = re.sub(r'\bFalse\b', 'false', json_text)
85
+ json_text = re.sub(r",\s*}", "}", json_text)
86
+ json_text = re.sub(r",\s*]", "]", json_text)
87
+ return json_text.strip()
88
  return text
89
 
90
+ def infer_filter(title, content):
91
+ """Função principal de inferência otimizada"""
92
+ log.info(f"🧠 Inferência iniciada para: {title}")
93
+ start_time = time.time()
 
 
 
 
94
 
95
+ chat_prompt = build_chat_prompt(title, content)
96
 
97
+ inputs = tokenizer(
98
+ chat_prompt,
99
+ return_tensors="pt",
100
+ truncation=True,
101
+ max_length=512,
102
+ padding=False,
103
+ add_special_tokens=False
104
+ )
105
+
106
+ input_ids = inputs.input_ids
107
+ attention_mask = inputs.attention_mask
108
+
109
+ with torch.no_grad(), torch.inference_mode():
110
+ outputs = model.generate(
111
+ input_ids=input_ids,
112
+ attention_mask=attention_mask,
113
+ generation_config=generation_config,
114
+ num_return_sequences=1,
115
+ output_scores=False,
116
+ return_dict_in_generate=False
117
  )
118
+
119
+ generated_tokens = outputs[0][len(input_ids[0]):]
120
+ generated = tokenizer.decode(
121
+ generated_tokens,
122
+ skip_special_tokens=True,
123
+ clean_up_tokenization_spaces=True
124
+ )
125
+
126
+ log.info("📤 Resultado gerado:")
127
+ log.info(generated)
128
+
129
+ json_result = extract_json(generated)
130
+
131
+ duration = time.time() - start_time
132
+ log.info(f"✅ JSON extraído em {duration:.2f}s")
133
+
134
+ # Limpeza de memória
135
+ del outputs, generated_tokens, inputs
136
+ gc.collect()
137
+
138
+ return json_result, duration
139
+
140
+ def analyze_news(title, content):
141
+ """Função principal de análise de notícias para Gradio"""
142
+ try:
143
+ if not title.strip() or not content.strip():
144
+ return "❌ Por favor, preencha tanto o título quanto o conteúdo.", "Erro: Campos obrigatórios não preenchidos."
145
 
146
+ json_result, duration = infer_filter(title, content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+ if json_result:
149
+ # Tenta validar e formatar o JSON
150
+ try:
151
+ parsed_json = json.loads(json_result)
152
+ formatted_json = json.dumps(parsed_json, indent=2, ensure_ascii=False)
153
+ status = f"✅ Análise concluída em {duration:.2f}s"
154
+ return status, formatted_json
155
+ except json.JSONDecodeError as e:
156
+ log.error(f"❌ Erro ao parsear JSON: {e}")
157
+ status = f"⚠️ JSON retornado como string devido a erro de parsing ({duration:.2f}s)"
158
+ return status, json_result
159
+ else:
160
+ return "❌ Não foi possível extrair JSON da resposta do modelo.", "Erro: Falha na extração do JSON."
161
 
162
  except Exception as e:
163
+ log.exception("❌ Erro inesperado:")
164
+ return f"❌ Erro durante a análise: {str(e)}", f"Erro: {str(e)}"
165
+
166
+ # -------- WARMUP --------
167
+ def warmup_model():
168
+ """Executa warmup do modelo"""
169
+ log.info("🔥 Executando warmup...")
170
+ try:
171
+ infer_filter("Test title", "Test content")
172
+ log.info("✅ Warmup concluído.")
173
+ except Exception as e:
174
+ log.warning(f"⚠️ Warmup falhou: {e}")
175
 
176
  # Interface Gradio
177
  def create_interface():
178
  with gr.Blocks(
179
+ title="Analisador de Notícias - Otimizado",
180
  theme=gr.themes.Soft(),
181
  css="""
182
  .gradio-container {
183
  max-width: 1200px !important;
184
  }
185
+ .performance-info {
186
+ background: #f0f9ff;
187
+ padding: 10px;
188
+ border-radius: 5px;
189
+ margin: 10px 0;
190
+ }
191
  """
192
  ) as demo:
193
 
194
+ gr.Markdown("# 📰 Analisador de Notícias - Otimizado")
195
+ gr.Markdown("Versão otimizada com técnicas de alto desempenho para CPU")
196
 
197
  with gr.Row():
198
  with gr.Column(scale=1):
 
211
  analyze_btn = gr.Button("🔍 Analisar Notícia", variant="primary")
212
 
213
  # Exemplos predefinidos
214
+ gr.Markdown("### Exemplos Rápidos:")
215
 
216
+ with gr.Row():
217
+ example_btn1 = gr.Button("📻 Músico", size="sm")
218
+ example_btn2 = gr.Button(" Esporte", size="sm")
219
+ example_btn3 = gr.Button("💼 Negócios", size="sm")
220
 
221
  with gr.Column(scale=1):
222
  output = gr.Textbox(
 
226
  show_copy_button=True
227
  )
228
 
 
229
  status = gr.Textbox(
230
  label="Status da Análise",
231
+ value="🟡 Aguardando entrada...",
232
  interactive=False
233
  )
234
+
235
+ # Informações de performance
236
+ with gr.Accordion("⚡ Otimizações Aplicadas", open=False):
237
+ gr.Markdown("""
238
+ **Técnicas de Otimização em CPU:**
239
+ - 🧵 Threads limitadas (OMP_NUM_THREADS=2)
240
+ - 🚫 Paralelismo de tokenizer desabilitado
241
+ - 💾 Uso otimizado de memória (bfloat16)
242
+ - 🔄 Cache de modelo ativado
243
+ - 🧹 Limpeza automática de memória
244
+ - 🎯 Modo de inferência otimizado
245
+ - 🔥 Warmup automático do modelo
246
+ """)
247
 
248
  # Exemplos predefinidos
249
  def load_example_1():
 
266
 
267
  # Event handlers
268
  analyze_btn.click(
269
+ fn=analyze_news,
270
  inputs=[title_input, content_input],
271
  outputs=[status, output]
272
  )
 
287
  )
288
 
289
  # Informações adicionais
290
+ with gr.Accordion("ℹ️ Informações Técnicas", open=False):
291
  gr.Markdown("""
292
+ **Configuração do Modelo:**
293
+ - Modelo: `habulaj/filterinstruct180`
294
+ - Formato: `torch.bfloat16` (otimizado para CPU)
295
+ - Max tokens: 128
296
+ - Beam search: Desabilitado (mais rápido)
297
+ - Cache: Ativado
298
 
299
+ **Performance:**
300
+ - Threads: 2 (OpenMP + MKL)
301
+ - Memória: Otimizada com limpeza automática
302
+ - Warmup: Executado automaticamente
303
  """)
304
 
305
  return demo
306
 
307
  if __name__ == "__main__":
308
+ # Executa warmup antes de iniciar a interface
309
+ warmup_model()
310
+
311
+ print("🚀 Iniciando interface Gradio otimizada...")
312
  demo = create_interface()
313
  demo.launch(
314
  share=False,