import gradio as gr
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, GenerationConfig
import torch
import re
import json
import time
import logging
import os
import gc
# Optimization settings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
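# Note: capping the BLAS/OpenMP and intra-op thread pools avoids thread
# oversubscription on small CPU instances (this assumes a ~2-vCPU Space, which
# is typical for the free tier); raise these values on larger hardware.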
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("news-filter-gradio")
device = "cpu"
torch.set_default_device(device)
# Load model and tokenizer
log.info("🚀 Loading model and tokenizer...")
model = AutoPeftModelForCausalLM.from_pretrained(
    "habulaj/filterinstruct2",
    device_map=device,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "habulaj/filterinstruct2",
    use_fast=True,
    padding_side="left",
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()
log.info("✅ Model loaded (eval mode).")
tokenizer.chat_template = """{% for message in messages %}
{%- if message['role'] == 'user' %}
{%- if loop.first %}
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{{ message['content'] }}<|eot_id|>
{%- else %}
<|start_header_id|>user<|end_header_id|>
{{ message['content'] }}<|eot_id|>
{%- endif %}
{%- elif message['role'] == 'assistant' %}
<|start_header_id|>assistant<|end_header_id|>
{{ message['content'] }}<|eot_id|>
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
<|start_header_id|>assistant<|end_header_id|>
{%- endif %}"""
generation_config = GenerationConfig(
    max_new_tokens=200,
    temperature=1.0,
    min_p=0.1,
    do_sample=True,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
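# The model is asked to answer with bare JSON, but small instruct models often
# wrap it in stray text. A greedy regex grabs everything from the first "{" to
# the last "}"; this breaks on multiple top-level objects, which is acceptable
# here since exactly one filter object is expected per reply.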
def extract_json(text):
    match = re.search(r'\{.*\}', text, flags=re.DOTALL)
    if match:
        return match.group(0)
    return text
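# Inference path: build a single-turn chat prompt, generate, then decode only
# the newly generated tokens. Slicing by token count is more robust than
# slicing the decoded string by the prompt's character length, which can drift
# when special tokens round-trip through the tokenizer differently.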
def analyze_news(title, content):
    try:
        log.info(f"🧠 Inference started for: {title}")
        start_time = time.time()
        messages = [
            {
                "role": "user",
                "content": f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
Title: "{title}"
Content: "{content}"
"""
            }
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        with torch.inference_mode():  # implies no_grad; stacking both is redundant
            outputs = model.generate(
                input_ids=inputs,
                generation_config=generation_config,
                num_return_sequences=1,
                output_scores=False,
                return_dict_in_generate=False,
            )
        # Decode only the tokens generated after the prompt.
        generated_only = tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        ).strip()
        json_result = extract_json(generated_only)
        duration = time.time() - start_time
        log.info(f"✅ JSON extracted in {duration:.2f}s")
        del outputs, inputs
        gc.collect()
        try:
            parsed_json = json.loads(json_result)
            return json.dumps(parsed_json, indent=2, ensure_ascii=False)
        except json.JSONDecodeError:
            return json_result
    except Exception as e:
        log.exception("❌ Unexpected error:")
        return f"Error during analysis: {str(e)}"
def warmup_model():
    log.info("🔥 Running warmup...")
    try:
        analyze_news("Test title", "Test content")
        log.info("✅ Warmup finished.")
    except Exception as e:
        log.warning(f"⚠️ Warmup failed: {e}")
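# UI layout: a two-column Blocks page with the inputs on the left and the JSON
# result plus a status line on the right; the click handler returns a
# (status, result) tuple matching the outputs list below.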
def create_interface():
    with gr.Blocks(title="News Analyzer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📰 News Analyzer")
        with gr.Row():
            with gr.Column(scale=1):
                title_input = gr.Textbox(
                    label="News Title",
                    placeholder="Enter the news title...",
                    lines=2,
                )
                content_input = gr.Textbox(
                    label="News Content",
                    placeholder="Enter the news content...",
                    lines=6,
                )
                analyze_btn = gr.Button("🔍 Analyze News", variant="primary")
            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="JSON Result",
                    lines=15,
                    max_lines=20,
                    show_copy_button=True,
                )
                status = gr.Textbox(
                    label="Status",
                    value="Waiting for input...",
                    interactive=False,
                )

        def update_status_and_analyze(title, content):
            if not title.strip() or not content.strip():
                return "❌ Fill in both title and content.", "Error: required fields missing."
            try:
                result = analyze_news(title, content)
                return "✅ Analysis complete!", result
            except Exception as e:
                return f"❌ Error: {str(e)}", f"Error: {str(e)}"

        analyze_btn.click(
            fn=update_status_and_analyze,
            inputs=[title_input, content_input],
            outputs=[status, output],
        )
    return demo
if __name__ == "__main__":
    warmup_model()
    log.info("🚀 Starting Gradio interface...")
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",  # listen on all interfaces so the Spaces proxy can reach it
        server_port=7860,
        show_error=True,
    )
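# A minimal sketch of calling the running app programmatically with
# gradio_client. The api_name is an assumption based on the handler's function
# name; run `Client(...).view_api()` to confirm the endpoint Gradio exposes.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   status, result = client.predict(
#       "Some headline",
#       "Some article body",
#       api_name="/update_status_and_analyze",
#   )
#   print(status, result)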