habulaj commited on
Commit
a15c41a
·
verified ·
1 Parent(s): 2560f7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -30
app.py CHANGED
@@ -1,39 +1,48 @@
1
  from fastapi import FastAPI, Query, HTTPException
2
  import torch
3
  import re
 
 
4
  from transformers import AutoTokenizer
5
  from peft import AutoPeftModelForCausalLM
6
 
7
- # Carrega modelo e tokenizer da Hugging Face - LoRA fine-tuned
 
 
 
 
 
 
 
8
  model_name = "habulaj/filter"
9
- print("Carregando tokenizer e modelo (CPU)...")
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
11
 
12
- # Otimizações de performance
13
  model = AutoPeftModelForCausalLM.from_pretrained(
14
  model_name,
15
  device_map="cpu",
16
- torch_dtype=torch.float32, # float32 é mais rápido em CPU
17
- low_cpu_mem_usage=True, # Reduz uso de memória
18
  )
19
  model.eval()
 
20
 
21
- # Compilação do modelo para otimizar (PyTorch 2.0+)
22
  try:
23
  model = torch.compile(model, mode="reduce-overhead")
24
- print("✅ Modelo compilado com torch.compile")
25
  except Exception as e:
26
- print(f"⚠️ torch.compile não disponível: {e}")
27
 
28
  # -------- FASTAPI --------
29
  app = FastAPI(title="News Filter JSON API")
30
 
31
- # -------- ROOT ENDPOINT --------
32
  @app.get("/")
33
  def read_root():
34
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
35
 
36
- # Função para inferência otimizada
37
  def infer_filter(title, content):
38
  prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
39
 
@@ -43,44 +52,48 @@ Title: "{title}"
43
  Content: "{content}"
44
  """
45
 
46
- # Otimizações de tokenização
 
 
47
  inputs = tokenizer(
48
- prompt,
49
  return_tensors="pt",
50
  truncation=True,
51
- max_length=512, # Limita tamanho do input
52
- padding=False # Não faz padding desnecessário
53
  )
54
  input_ids = inputs.input_ids.to("cpu")
55
-
56
  with torch.no_grad():
57
- # Configurações otimizadas para velocidade
58
  outputs = model.generate(
59
  input_ids=input_ids,
60
- max_new_tokens=100, # Reduzido de 128 para 100
61
- temperature=1.0, # Reduzido para ser mais determinístico
62
  do_sample=True,
63
  top_p=0.9,
64
- num_beams=1, # Beam search = 1 (greedy) é mais rápido
65
- early_stopping=True, # Para quando encontrar EOS
66
  eos_token_id=tokenizer.eos_token_id,
67
  pad_token_id=tokenizer.eos_token_id,
68
  )
69
-
70
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
-
72
- # Remove prompt do output
73
  generated = decoded[len(prompt):].strip()
74
-
75
- # Extrai JSON
 
 
76
  match = re.search(r"\{.*\}", generated, re.DOTALL)
77
  if match:
78
- result = match.group(0)
79
- return result
 
 
80
  else:
 
81
  return "⚠️ Failed to extract JSON. Output:\n" + generated
82
 
83
- # -------- API ROUTE --------
84
  @app.get("/filter")
85
  def get_filter(
86
  title: str = Query(..., description="Title of the news"),
@@ -89,10 +102,10 @@ def get_filter(
89
  try:
90
  json_output = infer_filter(title, content)
91
  import json
92
- # Retorna como dados brutos (parse do JSON)
93
  return json.loads(json_output)
94
  except json.JSONDecodeError:
95
- # Se não conseguir fazer parse, retorna como string
96
  return {"raw_output": json_output}
97
  except Exception as e:
 
98
  raise HTTPException(status_code=422, detail=str(e))
 
1
  from fastapi import FastAPI, Query, HTTPException
2
  import torch
3
  import re
4
+ import time
5
+ import logging
6
  from transformers import AutoTokenizer
7
  from peft import AutoPeftModelForCausalLM
8
 
9
+ # -------- LOGGING CONFIG --------
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format="%(asctime)s [%(levelname)s] %(message)s",
13
+ )
14
+ log = logging.getLogger("news-filter")
15
+
16
+ # -------- CARREGAMENTO DE MODELO --------
17
  model_name = "habulaj/filter"
18
+ log.info("🚀 Iniciando carregamento do modelo e tokenizer...")
19
+
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ log.info("✅ Tokenizer carregado.")
22
 
 
23
  model = AutoPeftModelForCausalLM.from_pretrained(
24
  model_name,
25
  device_map="cpu",
26
+ torch_dtype=torch.float32,
27
+ low_cpu_mem_usage=True,
28
  )
29
  model.eval()
30
+ log.info("✅ Modelo carregado e em modo eval.")
31
 
 
32
  try:
33
  model = torch.compile(model, mode="reduce-overhead")
34
+ log.info("✅ Modelo compilado com torch.compile.")
35
  except Exception as e:
36
+ log.warning(f"⚠️ torch.compile indisponível: {e}")
37
 
38
  # -------- FASTAPI --------
39
  app = FastAPI(title="News Filter JSON API")
40
 
 
41
  @app.get("/")
42
  def read_root():
43
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
44
 
45
+ # -------- INFERÊNCIA --------
46
  def infer_filter(title, content):
47
  prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
48
 
 
52
  Content: "{content}"
53
  """
54
 
55
+ log.info(f"🧠 Iniciando inferência para notícia:\n📰 Title: {title}\n📝 Content: {content[:100]}...")
56
+ start_time = time.time()
57
+
58
  inputs = tokenizer(
59
+ prompt,
60
  return_tensors="pt",
61
  truncation=True,
62
+ max_length=512,
63
+ padding=False,
64
  )
65
  input_ids = inputs.input_ids.to("cpu")
66
+
67
  with torch.no_grad():
 
68
  outputs = model.generate(
69
  input_ids=input_ids,
70
+ max_new_tokens=100,
71
+ temperature=1.0,
72
  do_sample=True,
73
  top_p=0.9,
74
+ num_beams=1,
75
+ early_stopping=True,
76
  eos_token_id=tokenizer.eos_token_id,
77
  pad_token_id=tokenizer.eos_token_id,
78
  )
79
+
80
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
81
  generated = decoded[len(prompt):].strip()
82
+
83
+ log.info("📤 Resposta bruta decodificada:")
84
+ log.info(generated)
85
+
86
  match = re.search(r"\{.*\}", generated, re.DOTALL)
87
  if match:
88
+ json_result = match.group(0)
89
+ duration = time.time() - start_time
90
+ log.info(f"✅ JSON extraído com sucesso em {duration:.2f}s")
91
+ return json_result
92
  else:
93
+ log.warning("⚠️ Não foi possível extrair JSON.")
94
  return "⚠️ Failed to extract JSON. Output:\n" + generated
95
 
96
+ # -------- ENDPOINT --------
97
  @app.get("/filter")
98
  def get_filter(
99
  title: str = Query(..., description="Title of the news"),
 
102
  try:
103
  json_output = infer_filter(title, content)
104
  import json
 
105
  return json.loads(json_output)
106
  except json.JSONDecodeError:
107
+ log.error("❌ Erro ao fazer parse do JSON retornado.")
108
  return {"raw_output": json_output}
109
  except Exception as e:
110
+ log.exception("❌ Erro inesperado durante a inferência:")
111
  raise HTTPException(status_code=422, detail=str(e))