habulaj committed on
Commit
2560f7d
·
verified ·
1 Parent(s): 9537362

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -3,10 +3,9 @@ import torch
3
  import re
4
  from transformers import AutoTokenizer
5
  from peft import AutoPeftModelForCausalLM
6
- import json
7
 
8
  # Carrega modelo e tokenizer da Hugging Face - LoRA fine-tuned
9
- model_name = "habulaj/filterinstruct"
10
  print("Carregando tokenizer e modelo (CPU)...")
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
@@ -29,6 +28,7 @@ except Exception as e:
29
  # -------- FASTAPI --------
30
  app = FastAPI(title="News Filter JSON API")
31
 
 
32
  @app.get("/")
33
  def read_root():
34
  return {"message": "News Filter JSON API is running!", "docs": "/docs"}
@@ -43,37 +43,44 @@ Title: "{title}"
43
  Content: "{content}"
44
  """
45
 
 
46
  inputs = tokenizer(
47
- prompt,
48
  return_tensors="pt",
49
  truncation=True,
50
- max_length=512,
51
- padding=False
52
  )
53
  input_ids = inputs.input_ids.to("cpu")
54
-
55
  with torch.no_grad():
 
56
  outputs = model.generate(
57
  input_ids=input_ids,
58
- max_new_tokens=100,
59
- temperature=1.0,
60
  do_sample=True,
61
  top_p=0.9,
62
- num_beams=1,
63
- early_stopping=True,
64
  eos_token_id=tokenizer.eos_token_id,
65
  pad_token_id=tokenizer.eos_token_id,
66
  )
67
-
68
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
69
  generated = decoded[len(prompt):].strip()
70
-
 
71
  match = re.search(r"\{.*\}", generated, re.DOTALL)
72
  if match:
73
- return match.group(0)
 
74
  else:
75
  return "⚠️ Failed to extract JSON. Output:\n" + generated
76
 
 
77
  @app.get("/filter")
78
  def get_filter(
79
  title: str = Query(..., description="Title of the news"),
@@ -81,8 +88,11 @@ def get_filter(
81
  ):
82
  try:
83
  json_output = infer_filter(title, content)
 
 
84
  return json.loads(json_output)
85
  except json.JSONDecodeError:
 
86
  return {"raw_output": json_output}
87
  except Exception as e:
88
  raise HTTPException(status_code=422, detail=str(e))
 
3
import json  # required by get_filter: json.loads / json.JSONDecodeError
import re

from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

# Load model and tokenizer from Hugging Face - LoRA fine-tuned
model_name = "habulaj/filter"
print("Carregando tokenizer e modelo (CPU)...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
11
 
 
28
# -------- FASTAPI --------
app = FastAPI(title="News Filter JSON API")


# -------- ROOT ENDPOINT --------
@app.get("/")
def read_root():
    """Health-check endpoint that points clients at the interactive docs."""
    payload = {
        "message": "News Filter JSON API is running!",
        "docs": "/docs",
    }
    return payload
 
43
  Content: "{content}"
44
  """
45
 
46
+ # Otimizações de tokenização
47
  inputs = tokenizer(
48
+ prompt,
49
  return_tensors="pt",
50
  truncation=True,
51
+ max_length=512, # Limita tamanho do input
52
+ padding=False # Não faz padding desnecessário
53
  )
54
  input_ids = inputs.input_ids.to("cpu")
55
+
56
  with torch.no_grad():
57
+ # Configurações otimizadas para velocidade
58
  outputs = model.generate(
59
  input_ids=input_ids,
60
+ max_new_tokens=100, # Reduzido de 128 para 100
61
+ temperature=1.0, # Reduzido para ser mais determinístico
62
  do_sample=True,
63
  top_p=0.9,
64
+ num_beams=1, # Beam search = 1 (greedy) é mais rápido
65
+ early_stopping=True, # Para quando encontrar EOS
66
  eos_token_id=tokenizer.eos_token_id,
67
  pad_token_id=tokenizer.eos_token_id,
68
  )
69
+
70
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
+
72
+ # Remove prompt do output
73
  generated = decoded[len(prompt):].strip()
74
+
75
+ # Extrai JSON
76
  match = re.search(r"\{.*\}", generated, re.DOTALL)
77
  if match:
78
+ result = match.group(0)
79
+ return result
80
  else:
81
  return "⚠️ Failed to extract JSON. Output:\n" + generated
82
 
83
# -------- API ROUTE --------
@app.get("/filter")
def get_filter(
    title: str = Query(..., description="Title of the news"),
    # NOTE(review): this parameter line sits in a diff-hunk gap and is
    # reconstructed from the call site infer_filter(title, content) and the
    # parallel `title` Query — confirm the original description text.
    content: str = Query(..., description="Content of the news"),
):
    """Run the LoRA news-filter model on a news item.

    Returns:
        The parsed JSON object on success; ``{"raw_output": ...}`` when the
        model output is not valid JSON.

    Raises:
        HTTPException: 422 with the error detail for any other inference
        failure.
    """
    # Imported before the try block so json.JSONDecodeError is resolvable
    # in the except clause even when infer_filter() itself raises (the
    # original imported json mid-try, after the call, causing a NameError
    # on that path).
    import json

    try:
        json_output = infer_filter(title, content)
        # Parse the model output into structured data.
        return json.loads(json_output)
    except json.JSONDecodeError:
        # Model produced non-JSON text — return it raw instead of failing.
        return {"raw_output": json_output}
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e))