habulaj commited on
Commit
2c0cc11
·
verified ·
1 Parent(s): caa0753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -9
app.py CHANGED
@@ -8,12 +8,25 @@ from peft import AutoPeftModelForCausalLM
8
  model_name = "habulaj/filter"
9
  print("Carregando tokenizer e modelo (CPU)...")
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
11
  model = AutoPeftModelForCausalLM.from_pretrained(
12
  model_name,
13
- device_map="cpu", # Força CPU
14
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # usa float16 se possível, senão float32
 
15
  )
16
- model.eval() # modo avaliação
 
 
 
 
 
 
 
 
 
 
17
 
18
  # -------- FASTAPI --------
19
  app = FastAPI(title="News Filter JSON API")
@@ -23,8 +36,13 @@ app = FastAPI(title="News Filter JSON API")
23
  def read_root():
24
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
25
 
26
- # Função para inferência
27
  def infer_filter(title, content):
 
 
 
 
 
28
  prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
29
 
30
  Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
@@ -33,17 +51,28 @@ Title: "{title}"
33
  Content: "{content}"
34
  """
35
 
36
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
37
  input_ids = inputs.input_ids.to("cpu")
38
 
39
  with torch.no_grad():
 
40
  outputs = model.generate(
41
  input_ids=input_ids,
42
- max_new_tokens=128,
43
- temperature=1.2,
44
  do_sample=True,
45
  top_p=0.9,
 
 
46
  eos_token_id=tokenizer.eos_token_id,
 
47
  )
48
 
49
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -54,7 +83,11 @@ Content: "{content}"
54
  # Extrai JSON
55
  match = re.search(r"\{.*\}", generated, re.DOTALL)
56
  if match:
57
- return match.group(0)
 
 
 
 
58
  else:
59
  return "⚠️ Failed to extract JSON. Output:\n" + generated
60
 
@@ -66,6 +99,11 @@ def get_filter(
66
  ):
67
  try:
68
  json_output = infer_filter(title, content)
69
- return {"filter": json_output}
 
 
 
 
 
70
  except Exception as e:
71
  raise HTTPException(status_code=422, detail=str(e))
 
8
  model_name = "habulaj/filter"
9
  print("Carregando tokenizer e modelo (CPU)...")
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ # Otimizações de performance
13
  model = AutoPeftModelForCausalLM.from_pretrained(
14
  model_name,
15
+ device_map="cpu",
16
+ torch_dtype=torch.float32, # float32 é mais rápido em CPU
17
+ low_cpu_mem_usage=True, # Reduz uso de memória
18
  )
19
+ model.eval()
20
+
21
+ # Compilação do modelo para otimizar (PyTorch 2.0+)
22
+ try:
23
+ model = torch.compile(model, mode="reduce-overhead")
24
+ print("✅ Modelo compilado com torch.compile")
25
+ except Exception as e:
26
+ print(f"⚠️ torch.compile não disponível: {e}")
27
+
28
+ # Cache para prompts similares
29
+ prompt_cache = {}
30
 
31
  # -------- FASTAPI --------
32
  app = FastAPI(title="News Filter JSON API")
 
36
  def read_root():
37
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
38
 
39
+ # Função para inferência otimizada
40
  def infer_filter(title, content):
41
+ # Cache key simples
42
+ cache_key = hash((title[:50], content[:100]))
43
+ if cache_key in prompt_cache:
44
+ return prompt_cache[cache_key]
45
+
46
  prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
47
 
48
  Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
 
51
  Content: "{content}"
52
  """
53
 
54
+ # Otimizações de tokenização
55
+ inputs = tokenizer(
56
+ prompt,
57
+ return_tensors="pt",
58
+ truncation=True,
59
+ max_length=512, # Limita tamanho do input
60
+ padding=False # Não faz padding desnecessário
61
+ )
62
  input_ids = inputs.input_ids.to("cpu")
63
 
64
  with torch.no_grad():
65
+ # Configurações otimizadas para velocidade
66
  outputs = model.generate(
67
  input_ids=input_ids,
68
+ max_new_tokens=100, # Reduzido de 128 para 100
69
+ temperature=1.0, # Reduzido para ser mais determinístico
70
  do_sample=True,
71
  top_p=0.9,
72
+ num_beams=1, # Beam search = 1 (greedy) é mais rápido
73
+ early_stopping=True, # Para quando encontrar EOS
74
  eos_token_id=tokenizer.eos_token_id,
75
+ pad_token_id=tokenizer.eos_token_id,
76
  )
77
 
78
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
83
  # Extrai JSON
84
  match = re.search(r"\{.*\}", generated, re.DOTALL)
85
  if match:
86
+ result = match.group(0)
87
+ # Cache o resultado (limitado a 100 entradas)
88
+ if len(prompt_cache) < 100:
89
+ prompt_cache[cache_key] = result
90
+ return result
91
  else:
92
  return "⚠️ Failed to extract JSON. Output:\n" + generated
93
 
 
99
  ):
100
  try:
101
  json_output = infer_filter(title, content)
102
+ import json
103
+ # Retorna como dados brutos (parse do JSON)
104
+ return json.loads(json_output)
105
+ except json.JSONDecodeError:
106
+ # Se não conseguir fazer parse, retorna como string
107
+ return {"raw_output": json_output}
108
  except Exception as e:
109
  raise HTTPException(status_code=422, detail=str(e))