import gradio as gr
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, GenerationConfig
import torch
import re
import json
import time
import logging
import os
import gc
from typing import Dict, Any, Optional, List, Tuple
import psutil
from contextlib import contextmanager
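# Keep CPU usage predictable on the shared Space: cap every math backend
# (OpenMP, MKL, OpenBLAS, etc.) to a few physical cores and turn
# torch.cuda.empty_cache into a no-op since this deployment is CPU-only.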
num_cores = psutil.cpu_count(logical=False) or 1  # cpu_count() can return None on some platforms
num_threads = min(4, num_cores)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = str(num_threads)
os.environ["MKL_NUM_THREADS"] = str(num_threads)
os.environ["OPENBLAS_NUM_THREADS"] = str(num_threads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(num_threads)
os.environ["NUMEXPR_NUM_THREADS"] = str(num_threads)
torch.set_num_threads(num_threads)
torch.set_num_interop_threads(1)
torch.backends.mkl.enabled = True
torch.backends.mkldnn.enabled = True
torch.backends.quantized.engine = 'qnnpack'
torch.cuda.empty_cache = lambda: None
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("news-filter-optimized")
device = "cpu"
torch.set_default_device(device)
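# Force a garbage-collection pass before and after each inference call so the
# Space's resident memory stays low between requests.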
@contextmanager
def memory_efficient_context():
try:
gc.collect()
yield
finally:
gc.collect()
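# Thin tokenizer wrapper that memoizes apply_chat_template() results, so
# repeated prompts (e.g. during warmup) skip re-templating and re-tokenizing.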
class OptimizedTokenizerWrapper:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self._template_cache = {}
    def apply_chat_template(self, messages, **kwargs):
        content = messages[0]['content'] if messages else ""
        # Key on the full message content; hashing only a prefix could return a
        # cached encoding for a different input that shares the same first characters.
        key = hash(content)
        if key not in self._template_cache:
            result = self.tokenizer.apply_chat_template(messages, **kwargs)
            if len(self._template_cache) > 100:
                self._template_cache.clear()
            self._template_cache[key] = result
        return self._template_cache[key]
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
def __getattr__(self, name):
return getattr(self.tokenizer, name)
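# Load the PEFT adapter together with its base model in float16; attention is
# kept on the eager path because this Space runs on CPU only.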
print("🚀 Carregando modelo...")
log.info("🚀 Carregando modelo...")
model_config = {
"device_map": device,
"torch_dtype": torch.float16,
"low_cpu_mem_usage": True,
"use_cache": True,
"trust_remote_code": True,
"attn_implementation": "eager",
}
model = AutoPeftModelForCausalLM.from_pretrained(
"habulaj/filterinstruct180",
**model_config
)
tokenizer = AutoTokenizer.from_pretrained(
"habulaj/filterinstruct180",
use_fast=True,
padding_side="left",
model_max_length=1024,
clean_up_tokenization_spaces=False,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer = OptimizedTokenizerWrapper(tokenizer)
model.eval()
for param in model.parameters():
param.requires_grad = False
try:
model = torch.compile(model, mode="reduce-overhead")
log.info("✅ Modelo compilado")
except Exception as e:
log.warning(f"⚠️ Torch compile não disponível: {e}")
if hasattr(model, 'fuse_linear_layers'):
model.fuse_linear_layers()
log.info("✅ Modelo carregado")
tokenizer.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'user' %}{% if loop.first %}<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{{ message['content'] }}<|eot_id|>{% else %}<|start_header_id|>user<|end_header_id|>
{{ message['content'] }}<|eot_id|>{% endif %}{% elif message['role'] == 'assistant' %}<|start_header_id|>assistant<|end_header_id|>
{{ message['content'] }}<|eot_id|>{% endif %}{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>
{% endif %}"""
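# Greedy decoding with a hard cap on new tokens: the model only needs to emit
# a short JSON object, so beams and sampling are unnecessary.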
generation_config = GenerationConfig(
    max_new_tokens=150,
    do_sample=False,  # greedy decoding, so sampling knobs such as temperature are unused
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.1,
    num_beams=1,
)
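# Pull the first JSON object out of the raw generation, in case the model
# wraps the JSON in extra prose despite the instruction not to.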
def extract_json_optimized(text: str) -> str:
    # Compile the pattern once and keep it as a function attribute.
    if not hasattr(extract_json_optimized, 'pattern'):
        # Greedy match from the first '{' to the last '}' so an object with
        # nested braces is captured whole instead of being cut at the first '}'.
        extract_json_optimized.pattern = re.compile(r'\{.*\}', re.DOTALL)
    match = extract_json_optimized.pattern.search(text)
    return match.group(0) if match else text
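# Build the single-turn chat message, truncating title/content so the prompt
# stays well within the 1024-token context limit.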
def preprocess_input_optimized(title: str, content: str) -> List[Dict[str, str]]:
max_title_length = 100
max_content_length = 500
    title = title[:max_title_length]
    content = content[:max_content_length]
return [{
"role": "user",
"content": f"""Analyze the news title and content, and return the filters in JSON format with the defined fields.
Please respond ONLY with the JSON filter, do NOT add any explanations, system messages, or extra text.
Title: "{title}"
Content: "{content}"
"""
}]
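# End-to-end inference: template and tokenize the prompt, generate greedily,
# strip the prompt tokens from the output, and return pretty-printed JSON
# (or the raw text if the output does not parse).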
def analyze_news_optimized(title: str, content: str) -> str:
try:
with memory_efficient_context():
start_time = time.time()
messages = preprocess_input_optimized(title, content)
inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
padding=False,
truncation=True,
max_length=1024,
)
            # inference_mode() already implies no_grad(); float16 CPU autocast
            # needs a recent PyTorch (older builds only accept bfloat16 here).
            with torch.inference_mode():
                with torch.autocast(device_type='cpu', dtype=torch.float16):
outputs = model.generate(
inputs,
generation_config=generation_config,
num_return_sequences=1,
output_scores=False,
output_hidden_states=False,
output_attentions=False,
return_dict_in_generate=False,
use_cache=True,
do_sample=False,
)
generated_tokens = outputs[0][inputs.shape[1]:]
generated_text = tokenizer.decode(
generated_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
json_result = extract_json_optimized(generated_text)
duration = time.time() - start_time
log.info(f"✅ Análise concluída em {duration:.2f}s")
del outputs, inputs, generated_tokens
try:
parsed_json = json.loads(json_result)
return json.dumps(parsed_json, indent=2, ensure_ascii=False)
except json.JSONDecodeError:
return json_result
except Exception as e:
log.exception("❌ Erro durante análise:")
return f"Erro durante a análise: {str(e)}"
def warmup_optimized():
log.info("🔥 Executando warmup...")
try:
for i in range(3):
result = analyze_news_optimized(f"Test title {i}", f"Test content {i}")
log.info(f"Warmup {i+1}/3 concluído")
gc.collect()
log.info("✅ Warmup concluído")
except Exception as e:
log.warning(f"⚠️ Warmup falhou: {e}")
def create_optimized_interface():
with gr.Blocks(
title="Analisador de Notícias - Ultra Otimizado",
theme=gr.themes.Monochrome(),
css="""
.gradio-container {
max-width: 1200px !important;
}
.performance-info {
background: #f8f9fa;
border-left: 4px solid #007bff;
padding: 15px;
margin: 10px 0;
}
"""
) as demo:
gr.Markdown("# 🚀 Analisador de Notícias - Ultra Otimizado")
with gr.Row():
with gr.Column(scale=1):
title_input = gr.Textbox(
label="Título da Notícia",
placeholder="Ex: Legendary Musician Carlos Mendes Dies at 78",
max_lines=3
)
content_input = gr.Textbox(
label="Conteúdo da Notícia",
placeholder="Ex: Carlos Mendes, the internationally acclaimed Brazilian guitarist...",
max_lines=6
)
analyze_btn = gr.Button("⚡ Analisar Notícia", variant="primary")
with gr.Row():
                    example_btn1 = gr.Button("📻 Example 1", size="sm")
                    example_btn2 = gr.Button("⚽ Example 2", size="sm")
                    example_btn3 = gr.Button("💼 Example 3", size="sm")
with gr.Column(scale=1):
output = gr.Textbox(
label="Resultado JSON",
lines=15,
max_lines=20,
show_copy_button=True
)
status = gr.Textbox(
label="Status",
value="⚡ Pronto para análise",
interactive=False
)
def analyze_with_status(title: str, content: str) -> Tuple[str, str]:
if not title.strip() or not content.strip():
return "❌ Preencha todos os campos", "Erro: Campos obrigatórios não preenchidos"
try:
start_time = time.time()
result = analyze_news_optimized(title, content)
duration = time.time() - start_time
return f"✅ Análise concluída em {duration:.2f}s", result
except Exception as e:
return f"❌ Erro: {str(e)}", f"Erro: {str(e)}"
examples = [
("Legendary Musician Carlos Mendes Dies at 78", "Carlos Mendes, the internationally acclaimed Brazilian guitarist and composer known for blending traditional bossa nova with modern jazz, has died at the age of 78."),
("Brazil Defeats Argentina 2-1 in Copa America Final", "In a thrilling match at the Maracana Stadium, Brazil secured victory over Argentina with goals from Neymar and Vinicius Jr. The match was watched by over 200 million viewers worldwide."),
("Tech Giant Announces Major Layoffs Affecting 10,000 Employees", "The technology company announced significant workforce reductions citing economic uncertainty and changing market conditions. The layoffs will affect multiple departments across different regions.")
]
analyze_btn.click(
fn=analyze_with_status,
inputs=[title_input, content_input],
outputs=[status, output]
)
example_btn1.click(
fn=lambda: examples[0],
outputs=[title_input, content_input]
)
example_btn2.click(
fn=lambda: examples[1],
outputs=[title_input, content_input]
)
example_btn3.click(
fn=lambda: examples[2],
outputs=[title_input, content_input]
)
return demo
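# Warm the model up, then serve the Gradio app on port 7860 (the port
# Hugging Face Spaces expects).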
if __name__ == "__main__":
warmup_optimized()
print("🚀 Iniciando interface...")
demo = create_optimized_interface()
demo.launch(
share=False,
server_name="0.0.0.0",
server_port=7860,
show_error=True,
max_threads=num_threads,
show_api=False,
)