habulaj commited on
Commit
cda1138
·
verified ·
1 Parent(s): 8efc71c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -53
app.py CHANGED
@@ -3,49 +3,72 @@ import torch
3
  import re
4
  import time
5
  import logging
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
 
 
 
 
 
 
 
 
7
 
8
  # -------- LOGGING CONFIG --------
9
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.FileHandler("inference.log"), logging.StreamHandler()])
10
  log = logging.getLogger("news-filter")
11
 
12
  # -------- LOAD MODEL --------
13
  model_name = "habulaj/filterinstruct"
14
  log.info("🚀 Carregando modelo e tokenizer...")
15
 
16
- # Configuração para quantização em 8-bit para CPU
17
- # BitsAndBytesConfig é primariamente para GPU, mas pode ser usado para indicar a intenção de quantização.
18
- # Para CPU, a quantização real pode depender do suporte do modelo e da biblioteca transformers.
19
- quantization_config = BitsAndBytesConfig(
20
- load_in_8bit=True,
21
- # bnb_4bit_compute_dtype=torch.float32, # Não é necessário para 8-bit e pode causar problemas em CPU
22
- # bnb_4bit_quant_type="nf4", # Não é necessário para 8-bit
23
- # bnb_4bit_use_double_quant=True, # Não é necessário para 8-bit
24
  )
25
 
26
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
27
 
28
- # Carregar o modelo com a configuração de quantização
29
- # Se o modelo não suportar 8-bit em CPU, ele fará fallback para float32.
30
- model = AutoModelForCausalLM.from_pretrained(
31
  model_name,
32
  device_map="cpu",
33
- torch_dtype=torch.float32, # Manter float32 para garantir compatibilidade com CPU
34
  low_cpu_mem_usage=True,
35
- quantization_config=quantization_config, # Aplicar configuração de quantização
 
36
  )
 
37
  model.eval()
38
  log.info("✅ Modelo carregado (eval mode).")
39
 
40
- # Otimização para CPU: Habilitar formato de memória contígua
41
- try:
42
- torch.backends.cpu.enable_contiguous_memory_format()
43
- log.info("✅ Formato de memória contígua para CPU habilitado.")
44
- except Exception as e:
45
- log.warning(f"⚠️ Não foi possível habilitar formato de memória contígua para CPU: {e}")
 
 
 
 
 
 
 
46
 
 
47
  try:
48
- model = torch.compile(model, mode="reduce-overhead")
 
 
 
 
 
49
  log.info("✅ Modelo compilado com torch.compile.")
50
  except Exception as e:
51
  log.warning(f"⚠️ torch.compile não disponível: {e}")
@@ -57,50 +80,72 @@ app = FastAPI(title="News Filter JSON API")
57
  def read_root():
58
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
59
 
60
- # -------- INFERENCE --------
61
  def infer_filter(title, content):
62
- prompt = f"""Analyze the news title and content, and return the filters in strict JSON format.\n\nUse only double quotes for all property names and string values. Use lowercase `true` and `false` for booleans. Do not include any explanations, labels, or comments.\n\nTitle: "{title}"\nContent: "{content}"\n"""
63
-
 
 
 
 
64
  log.info(f"🧠 Inferência iniciada para: {title}")
65
  start_time = time.time()
66
-
 
67
  inputs = tokenizer(
68
  prompt,
69
  return_tensors="pt",
70
  truncation=True,
71
- max_length=512,
72
- padding=True,
 
73
  )
74
- input_ids = inputs.input_ids.to("cpu")
75
- attention_mask = inputs.attention_mask.to("cpu")
76
-
 
 
77
  with torch.no_grad():
78
- outputs = model.generate(
79
- input_ids=input_ids,
80
- attention_mask=attention_mask,
81
- max_new_tokens=100,
82
- temperature=1.0,
83
- do_sample=False,
84
- top_k=50,
85
- no_repeat_ngram_size=2,
86
- num_beams=1,
87
- eos_token_id=tokenizer.eos_token_id,
88
- pad_token_id=tokenizer.eos_token_id,
89
- )
90
-
91
- decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
92
- generated = decoded[len(prompt):].strip()
93
-
 
 
 
 
94
  log.info("📤 Resultado gerado:")
95
  log.info(generated)
96
-
97
- match = re.search(r"\\{.*\\}", generated, re.DOTALL)
 
98
  if match:
99
  duration = time.time() - start_time
 
100
  log.info(f"✅ JSON extraído em {duration:.2f}s")
101
- return match.group(0)
 
 
 
 
 
102
  else:
103
  log.warning("⚠️ Falha ao extrair JSON.")
 
 
 
104
  raise HTTPException(status_code=404, detail="Unable to extract JSON from model output.")
105
 
106
  # -------- API --------
@@ -114,7 +159,19 @@ def get_filter(
114
  import json
115
  return json.loads(json_output)
116
  except HTTPException as he:
117
- raise he # já tratado
118
  except Exception as e:
119
  log.exception("❌ Erro inesperado:")
120
- raise HTTPException(status_code=404, detail="Invalid or malformed JSON output from model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import re
4
  import time
5
  import logging
6
+ import os
7
+ from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig
8
+ from peft import AutoPeftModelForCausalLM
9
+ import gc
10
+
11
+ # -------- CONFIGURAÇÕES DE OTIMIZAÇÃO --------
12
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
+ os.environ["OMP_NUM_THREADS"] = "2" # Ajuste para seus 2 vcpus
14
+ os.environ["MKL_NUM_THREADS"] = "2"
15
+ torch.set_num_threads(2)
16
+ torch.set_num_interop_threads(1)
17
 
18
  # -------- LOGGING CONFIG --------
19
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
20
  log = logging.getLogger("news-filter")
21
 
22
  # -------- LOAD MODEL --------
23
  model_name = "habulaj/filterinstruct"
24
  log.info("🚀 Carregando modelo e tokenizer...")
25
 
26
+ # Tokenizer otimizado
27
+ tokenizer = AutoTokenizer.from_pretrained(
28
+ model_name,
29
+ use_fast=True, # Usa tokenizer fast se disponível
30
+ padding_side="left" # Padding à esquerda para melhor performance
 
 
 
31
  )
32
 
33
+ # Configurar pad_token se não existir
34
+ if tokenizer.pad_token is None:
35
+ tokenizer.pad_token = tokenizer.eos_token
36
 
37
+ # Modelo otimizado
38
+ model = AutoPeftModelForCausalLM.from_pretrained(
 
39
  model_name,
40
  device_map="cpu",
41
+ torch_dtype=torch.bfloat16, # bfloat16 é mais rápido que float32 em CPU moderna
42
  low_cpu_mem_usage=True,
43
+ use_cache=True, # Cache interno do modelo
44
+ trust_remote_code=True
45
  )
46
+
47
  model.eval()
48
  log.info("✅ Modelo carregado (eval mode).")
49
 
50
+ # Configuração de geração otimizada
51
+ generation_config = GenerationConfig(
52
+ max_new_tokens=100,
53
+ temperature=1.0,
54
+ do_sample=False,
55
+ num_beams=1,
56
+ use_cache=True,
57
+ eos_token_id=tokenizer.eos_token_id,
58
+ pad_token_id=tokenizer.eos_token_id,
59
+ no_repeat_ngram_size=2,
60
+ repetition_penalty=1.1,
61
+ length_penalty=1.0
62
+ )
63
 
64
+ # Torch compile com configurações otimizadas
65
  try:
66
+ model = torch.compile(
67
+ model,
68
+ mode="reduce-overhead",
69
+ fullgraph=True,
70
+ dynamic=False
71
+ )
72
  log.info("✅ Modelo compilado com torch.compile.")
73
  except Exception as e:
74
  log.warning(f"⚠️ torch.compile não disponível: {e}")
 
80
  def read_root():
81
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
82
 
83
+ # -------- INFERENCE OTIMIZADA --------
84
  def infer_filter(title, content):
85
+ # Prompt mais conciso para reduzir tokens
86
+ prompt = f"""Analyze and return JSON filters:
87
+ Title: "{title}"
88
+ Content: "{content}"
89
+ """
90
+
91
  log.info(f"🧠 Inferência iniciada para: {title}")
92
  start_time = time.time()
93
+
94
+ # Tokenização otimizada
95
  inputs = tokenizer(
96
  prompt,
97
  return_tensors="pt",
98
  truncation=True,
99
+ max_length=384, # Reduzido de 512 para acelerar
100
+ padding=False, # Sem padding desnecessário
101
+ add_special_tokens=True,
102
  )
103
+
104
+ input_ids = inputs.input_ids
105
+ attention_mask = inputs.attention_mask
106
+
107
+ # Geração otimizada
108
  with torch.no_grad():
109
+ with torch.inference_mode(): # Modo de inferência mais rápido
110
+ outputs = model.generate(
111
+ input_ids=input_ids,
112
+ attention_mask=attention_mask,
113
+ generation_config=generation_config,
114
+ # Parâmetros adicionais de otimização
115
+ early_stopping=True,
116
+ num_return_sequences=1,
117
+ output_scores=False,
118
+ return_dict_in_generate=False,
119
+ )
120
+
121
+ # Decodificação otimizada
122
+ generated_tokens = outputs[0][len(input_ids[0]):]
123
+ generated = tokenizer.decode(
124
+ generated_tokens,
125
+ skip_special_tokens=True,
126
+ clean_up_tokenization_spaces=True
127
+ )
128
+
129
  log.info("📤 Resultado gerado:")
130
  log.info(generated)
131
+
132
+ # Regex otimizada
133
+ match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', generated, re.DOTALL)
134
  if match:
135
  duration = time.time() - start_time
136
+ json_result = match.group(0)
137
  log.info(f"✅ JSON extraído em {duration:.2f}s")
138
+
139
+ # Limpeza de memória
140
+ del outputs, generated_tokens, inputs
141
+ gc.collect()
142
+
143
+ return json_result
144
  else:
145
  log.warning("⚠️ Falha ao extrair JSON.")
146
+ # Limpeza de memória mesmo em caso de erro
147
+ del outputs, generated_tokens, inputs
148
+ gc.collect()
149
  raise HTTPException(status_code=404, detail="Unable to extract JSON from model output.")
150
 
151
  # -------- API --------
 
159
  import json
160
  return json.loads(json_output)
161
  except HTTPException as he:
162
+ raise he
163
  except Exception as e:
164
  log.exception("❌ Erro inesperado:")
165
+ raise HTTPException(status_code=404, detail="Invalid or malformed JSON output from model.")
166
+
167
+ # -------- WARMUP (OPCIONAL) --------
168
+ @app.on_event("startup")
169
+ async def warmup():
170
+ """Faz um warmup do modelo para otimizar as primeiras execuções"""
171
+ log.info("🔥 Executando warmup...")
172
+ try:
173
+ # Exemplo simples para warmup
174
+ infer_filter("Test title", "Test content")
175
+ log.info("✅ Warmup concluído.")
176
+ except Exception as e:
177
+ log.warning(f"⚠️ Warmup falhou: {e}")