habulaj committed
Commit 3988dcb · verified · 1 Parent(s): 480aadb

Update app.py

Files changed (1)
  1. app.py +25 -25
app.py CHANGED
@@ -7,18 +7,14 @@ from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM
 
 # -------- LOGGING CONFIG --------
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s",
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 log = logging.getLogger("news-filter")
 
-# -------- CARREGAMENTO DE MODELO --------
+# -------- LOAD MODEL --------
 model_name = "habulaj/filter"
-log.info("🚀 Iniciando carregamento do modelo e tokenizer...")
+log.info("🚀 Carregando modelo e tokenizer...")
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-log.info("✅ Tokenizer carregado.")
 
 model = AutoPeftModelForCausalLM.from_pretrained(
     model_name,
@@ -27,21 +23,22 @@ model = AutoPeftModelForCausalLM.from_pretrained(
     low_cpu_mem_usage=True,
 )
 model.eval()
-log.info("✅ Modelo carregado e em modo eval.")
+log.info("✅ Modelo carregado (eval mode).")
 
 try:
+    model = torch.compile(model, mode="reduce-overhead")
     log.info("✅ Modelo compilado com torch.compile.")
 except Exception as e:
-    log.warning(f"⚠️ torch.compile indisponível: {e}")
+    log.warning(f"⚠️ torch.compile não disponível: {e}")
 
-# -------- FASTAPI --------
+# -------- FASTAPI INIT --------
 app = FastAPI(title="News Filter JSON API")
 
 @app.get("/")
 def read_root():
     return {"message": "News Filter JSON API is running!", "docs": "/docs"}
 
-# -------- INFERÊNCIA --------
+# -------- INFERENCE --------
 def infer_filter(title, content):
     prompt = f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
 
@@ -51,27 +48,30 @@ Title: "{title}"
 Content: "{content}"
 """
 
-    log.info(f"🧠 Iniciando inferência para notícia:\n📰 Title: {title}\n📝 Content: {content[:100]}...")
+    log.info(f"🧠 Inferência iniciada para: {title}")
     start_time = time.time()
 
+    # Tokenização + attention mask
     inputs = tokenizer(
         prompt,
         return_tensors="pt",
         truncation=True,
         max_length=512,
-        padding=False,
+        padding=True,
     )
     input_ids = inputs.input_ids.to("cpu")
+    attention_mask = inputs.attention_mask.to("cpu")
 
     with torch.no_grad():
         outputs = model.generate(
             input_ids=input_ids,
+            attention_mask=attention_mask,
             max_new_tokens=100,
             temperature=1.0,
-            do_sample=True,
-            top_p=0.9,
+            do_sample=False,  # Greedy decoding
+            top_k=50,  # Razoável para limitar
+            no_repeat_ngram_size=2,
             num_beams=1,
-            early_stopping=True,
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.eos_token_id,
         )
@@ -79,32 +79,32 @@ Content: "{content}"
     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
     generated = decoded[len(prompt):].strip()
 
-    log.info("📤 Resposta bruta decodificada:")
+    log.info("📤 Resultado gerado:")
     log.info(generated)
 
     match = re.search(r"\{.*\}", generated, re.DOTALL)
     if match:
-        json_result = match.group(0)
         duration = time.time() - start_time
-        log.info(f"✅ JSON extraído com sucesso em {duration:.2f}s")
+        json_result = match.group(0)
+        log.info(f"✅ JSON extraído em {duration:.2f}s")
         return json_result
     else:
-        log.warning("⚠️ Não foi possível extrair JSON.")
+        log.warning("⚠️ Falha ao extrair JSON.")
        return "⚠️ Failed to extract JSON. Output:\n" + generated
 
-# -------- ENDPOINT --------
+# -------- API --------
 @app.get("/filter")
 def get_filter(
-    title: str = Query(..., description="Title of the news"),
-    content: str = Query(..., description="Content of the news")
+    title: str = Query(..., description="News title"),
+    content: str = Query(..., description="News content")
 ):
     try:
         json_output = infer_filter(title, content)
         import json
         return json.loads(json_output)
     except json.JSONDecodeError:
-        log.error("❌ Erro ao fazer parse do JSON retornado.")
+        log.error("❌ Erro ao fazer parse do JSON.")
         return {"raw_output": json_output}
     except Exception as e:
-        log.exception("❌ Erro inesperado durante a inferência:")
+        log.exception("❌ Erro inesperado:")
         raise HTTPException(status_code=422, detail=str(e))
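A note on the new generation settings: with do_sample=False, generate() decodes greedily, so sampling-only parameters such as temperature and top_k have no effect on the output (num_beams=1 keeps it single-path), while no_repeat_ngram_size=2 still applies. A minimal standalone sketch of the same decoding setup, using gpt2 as a stand-in for the habulaj/filter adapter:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in model, not the app's PEFT adapter
lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

enc = tok("Analyze the news title and content:", return_tensors="pt")
with torch.no_grad():
    out = lm.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,  # explicit mask, as the commit now passes
        max_new_tokens=50,
        do_sample=False,           # greedy decoding
        no_repeat_ngram_size=2,
        num_beams=1,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,  # gpt2 has no pad token; reuse EOS
    )
print(tok.decode(out[0], skip_special_tokens=True))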
 
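For quick manual testing, a hypothetical client call against the /filter endpoint this commit touches; the base URL is a placeholder (Hugging Face Spaces typically serve on port 7860):

import requests

resp = requests.get(
    "http://localhost:7860/filter",  # placeholder base URL
    params={
        "title": "Example headline",
        "content": "Example article body...",
    },
    timeout=120,  # CPU-only generation can take a while
)
print(resp.json())  # parsed filter JSON, or {"raw_output": ...} on parse failure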