habulaj committed
Commit 59d7833 · verified · 1 Parent(s): 9650ab7

Update app.py

Files changed (1):
  1. app.py +100 -94

app.py CHANGED
@@ -1,17 +1,12 @@
  from fastapi import FastAPI, Query, HTTPException
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
  import torch
  import re
  import time
  import logging
  import os
- import gc
- import json
  from transformers import AutoTokenizer, GenerationConfig
  from peft import AutoPeftModelForCausalLM
- from unsloth.chat_templates import get_chat_template
- from unsloth import FastLanguageModel

  # -------- OPTIMIZATION SETTINGS --------
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -20,16 +15,19 @@ os.environ["MKL_NUM_THREADS"] = "2"
  torch.set_num_threads(2)
  torch.set_num_interop_threads(1)

- # -------- LOGGING --------
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
  log = logging.getLogger("news-filter")

- # -------- MODEL --------
  model_name = "habulaj/filterinstruct180"
  log.info("🚀 Loading model and tokenizer...")

- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
- tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")

  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token
@@ -39,11 +37,12 @@ model = AutoPeftModelForCausalLM.from_pretrained(
      device_map="cpu",
      torch_dtype=torch.bfloat16,
      low_cpu_mem_usage=True,
      trust_remote_code=True
  )
- FastLanguageModel.for_inference(model, cpu=True)
  model.eval()
- log.info("✅ Model loaded (eval mode).")

  generation_config = GenerationConfig(
      max_new_tokens=128,
@@ -53,113 +52,120 @@ generation_config = GenerationConfig(
      use_cache=True,
      eos_token_id=tokenizer.eos_token_id,
      pad_token_id=tokenizer.eos_token_id,
      repetition_penalty=1.1,
      length_penalty=1.0
  )

- # -------- FASTAPI --------
  app = FastAPI(title="News Filter JSON API")

  @app.get("/")
  def read_root():
      return {"message": "News Filter JSON API is running!", "docs": "/docs"}

- @app.get("/filter")
- def get_filter(
-     title: str = Query(..., description="News title"),
-     content: str = Query(..., description="News content")
- ):
-     try:
-         result = infer_filter(title, content)
-         try:
-             return {"result": json.loads(result)}
-         except json.JSONDecodeError:
-             return {"result": result, "warning": "Returned as string due to JSON parsing error"}
-     except HTTPException as he:
-         raise he
-     except Exception as e:
-         log.exception("❌ Unexpected error:")
-         raise HTTPException(status_code=500, detail="Internal server error during inference.")
-
- @app.on_event("startup")
- async def warmup():
-     log.info("🔥 Running warmup...")
-     try:
-         infer_filter("Test title", "Test content")
-         log.info("✅ Warmup complete.")
-     except Exception as e:
-         log.warning(f"⚠️ Warmup failed: {e}")
-
  # -------- INFERENCE --------
  def infer_filter(title, content):
-     messages = [
-         {
-             "role": "user",
-             "content": """Analyze the news title and content, and return the filters in JSON format with the defined fields.
-
- Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
-
- Title: "New 'Star Wars' Movie Announced"
- Content: "Lucasfilm confirmed a new Star Wars movie set to release in 2026, directed by a rising filmmaker."
- """
-         },
-         {
-             "role": "assistant",
-             "content": '{ "death_related": false, "relevance": "high", "global_interest": true, "entity_type": "movie", "entity_name": "Star Wars", "breaking_news": true, "has_video_content": false }'
-         },
-         {
-             "role": "user",
-             "content": """Analyze the news title and content, and return the filters in JSON format with the defined fields.
-
- Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
-
- Title: "Legendary Musician Carlos Mendes Dies at 78"
- Content: "Carlos Mendes, the internationally acclaimed Brazilian guitarist and composer known for blending traditional bossa nova with modern jazz, has died at the age of 78."
- """
-         },
-         {
-             "role": "assistant",
-             "content": '{ "death_related": true, "relevance": "high", "global_interest": true, "entity_type": "person", "entity_name": "Carlos Mendes", "breaking_news": true, "has_video_content": false }'
-         },
-         {
-             "role": "user",
-             "content": f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
-
- Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
-
- Title: "{title}"
- Content: "{content}"
- """
-         }
-     ]
-
      log.info(f"🧠 Inference started for: {title}")
      start_time = time.time()

-     inputs = tokenizer.apply_chat_template(
-         messages,
-         tokenize=True,
-         add_generation_prompt=True,
          return_tensors="pt",
-     ).to("cpu")

      with torch.no_grad(), torch.inference_mode():
          outputs = model.generate(
-             input_ids=inputs,
              generation_config=generation_config,
          )

-     prompt_text = tokenizer.decode(inputs[0], skip_special_tokens=True)
-     full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     generated = full_output[len(prompt_text):].strip()

-     json_str = extract_json(generated)
      duration = time.time() - start_time
      log.info(f"✅ JSON extracted in {duration:.2f}s")
-     return json_str

  def extract_json(text):
-     match = re.search(r'\{.*?\}', text, flags=re.DOTALL)
      if match:
-         return match.group(0)
-     return text

  from fastapi import FastAPI, Query, HTTPException
  import torch
  import re
  import time
  import logging
  import os
  from transformers import AutoTokenizer, GenerationConfig
  from peft import AutoPeftModelForCausalLM
+ import gc

  # -------- OPTIMIZATION SETTINGS --------
  os.environ["TOKENIZERS_PARALLELISM"] = "false"

  torch.set_num_threads(2)
  torch.set_num_interop_threads(1)

+ # -------- LOGGING CONFIG --------
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
  log = logging.getLogger("news-filter")

+ # -------- LOAD MODEL --------
  model_name = "habulaj/filterinstruct180"
  log.info("🚀 Loading model and tokenizer...")

+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name,
+     use_fast=True,
+     padding_side="left"
+ )

  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

  model = AutoPeftModelForCausalLM.from_pretrained(
      device_map="cpu",
      torch_dtype=torch.bfloat16,
      low_cpu_mem_usage=True,
+     use_cache=True,
      trust_remote_code=True
  )
+
  model.eval()
+ log.info("✅ Model loaded (eval mode).")

  generation_config = GenerationConfig(
      max_new_tokens=128,

      use_cache=True,
      eos_token_id=tokenizer.eos_token_id,
      pad_token_id=tokenizer.eos_token_id,
+     no_repeat_ngram_size=2,
      repetition_penalty=1.1,
      length_penalty=1.0
  )

+ # -------- FASTAPI INIT --------
  app = FastAPI(title="News Filter JSON API")

  @app.get("/")
  def read_root():
      return {"message": "News Filter JSON API is running!", "docs": "/docs"}

  # -------- INFERENCE --------
  def infer_filter(title, content):
      log.info(f"🧠 Inference started for: {title}")
      start_time = time.time()

+     chat_prompt = build_chat_prompt(title, content)
+
+     inputs = tokenizer(
+         chat_prompt,
          return_tensors="pt",
+         truncation=True,
+         max_length=512,
+         padding=False,
+         add_special_tokens=False
+     )
+
+     input_ids = inputs.input_ids
+     attention_mask = inputs.attention_mask

      with torch.no_grad(), torch.inference_mode():
          outputs = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
              generation_config=generation_config,
+             num_return_sequences=1,
+             output_scores=False,
+             return_dict_in_generate=False
          )

+     generated_tokens = outputs[0][len(input_ids[0]):]
+     generated = tokenizer.decode(
+         generated_tokens,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=True
+     )
+
+     log.info("📤 Generated result:")
+     log.info(generated)
+
+     json_result = extract_json(generated)

      duration = time.time() - start_time
      log.info(f"✅ JSON extracted in {duration:.2f}s")
+
+     # Memory cleanup
+     del outputs, generated_tokens, inputs
+     gc.collect()
+
+     if json_result:
+         return json_result
+     else:
+         raise HTTPException(status_code=404, detail="Unable to extract JSON from model output.")
+
+ def build_chat_prompt(title: str, content: str) -> str:
+     return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+ Analyze the news title and content, and return the filters in JSON format with the defined fields.
+
+ Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
+
+ Title: "{title}"
+ Content: "{content}"<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

  def extract_json(text):
+     match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
      if match:
+         json_text = match.group(0)
+
+         # Common conversions
+         json_text = re.sub(r"'", '"', json_text)
+         json_text = re.sub(r'\bTrue\b', 'true', json_text)
+         json_text = re.sub(r'\bFalse\b', 'false', json_text)
+         json_text = re.sub(r",\s*}", "}", json_text)
+         json_text = re.sub(r",\s*]", "]", json_text)
+         return json_text.strip()
+     return text
+
+ # -------- API ROUTE --------
+ @app.get("/filter")
+ def get_filter(
+     title: str = Query(..., description="News title"),
+     content: str = Query(..., description="News content")
+ ):
+     try:
+         json_output = infer_filter(title, content)
+         import json
+         try:
+             parsed = json.loads(json_output)
+             return {"result": parsed}
+         except json.JSONDecodeError as e:
+             log.error(f"❌ Error parsing JSON: {e}")
+             return {"result": json_output, "warning": "JSON returned as string due to parsing error"}
+     except HTTPException as e:
+         raise e
+     except Exception as e:
+         log.exception("❌ Unexpected error:")
+         raise HTTPException(status_code=500, detail="Internal server error during inference.")
+
+ @app.on_event("startup")
+ async def warmup():
+     log.info("🔥 Running warmup...")
+     try:
+         infer_filter("Test title", "Test content")
+         log.info("✅ Warmup complete.")
+     except Exception as e:
+         log.warning(f"⚠️ Warmup failed: {e}")
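
For reference, a minimal client sketch for the updated /filter route. The uvicorn command and base URL are assumptions for local testing, not part of the commit:

# Assumes the app has been started locally first, e.g.: uvicorn app:app --port 8000
import requests

resp = requests.get(
    "http://localhost:8000/filter",
    params={
        "title": "New 'Star Wars' Movie Announced",
        "content": "Lucasfilm confirmed a new Star Wars movie set to release in 2026.",
    },
    timeout=120,  # CPU-only generation is slow; allow a generous timeout
)
resp.raise_for_status()
print(resp.json())
# Success shape: {"result": {"death_related": false, ...}}
# If JSON parsing fails, the route returns the raw string plus a "warning" field.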
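The rewritten extract_json grabs the first brace group (one level of nesting) and then normalizes common model quirks: single quotes, Python booleans, and trailing commas. A quick self-contained check of that behavior, copying the committed regex and substitutions:

import json
import re

def extract_json(text):
    # Copy of the committed helper, reproduced here only for the demo
    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
    if match:
        json_text = match.group(0)
        json_text = re.sub(r"'", '"', json_text)             # single -> double quotes
        json_text = re.sub(r'\bTrue\b', 'true', json_text)   # Python -> JSON booleans
        json_text = re.sub(r'\bFalse\b', 'false', json_text)
        json_text = re.sub(r",\s*}", "}", json_text)         # drop trailing commas
        json_text = re.sub(r",\s*]", "]", json_text)
        return json_text.strip()
    return text

raw = "Sure! {'death_related': False, 'relevance': 'high',}"
print(json.loads(extract_json(raw)))
# -> {'death_related': False, 'relevance': 'high'}

Note that the blanket quote substitution corrupts any value that itself contains an apostrophe (e.g. an entity_name like "Star Wars' Sequel"), which is plausible in this news domain.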