import os
import threading
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
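# The access token is read from an environment variable (e.g. a Hugging Face Space secret) named "Key"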
hf_token = os.getenv("Key")
# Model configuration
model_id = "lxcorp/Synap-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=hf_token,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
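# Note: float16 on GPU roughly halves weight memory vs float32 (~4 GB vs ~8 GB for a 2B-parameter model)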
# Custom CSS for the UI
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');
* {
    font-family: 'JetBrains Mono', monospace !important;
}
html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}
textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""
# Global stop flag shared between the Gradio callbacks and the streaming loop
stop_signal = False

def stop_stream():
    global stop_signal
    stop_signal = True
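# Note: setting the flag only breaks the streaming loop below; model.generate keeps
# running in its background thread until it finishes or hits max_new_tokens.
# Aborting generation itself would require a transformers StoppingCriteria.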
# Streaming generation
def generate_response(message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False
    # Prompt kept in Portuguese, presumably the model's target language
    # ("Context: ...\nShort, direct answer:")
    prompt = f"Contexto: {message}\nResposta curta e direta:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )
    # model.generate blocks, so it runs in a background thread while the
    # streamer feeds decoded tokens back to this generator
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    full_text = ""
    for token in streamer:
        if stop_signal:
            break
        full_text += token
        yield full_text.strip()
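# Standalone usage sketch (hypothetical message and parameter values), consuming
# the generator directly without the UI:
#     for partial in generate_response("Qual a capital do Brasil?", 128, 0.7, 0.95):
#         print(partial)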
# Gradio interface
with gr.Blocks(css=css, theme="NoCrypt/miku") as app:
    chatbot = gr.Chatbot(label="Synap - 2B", elem_id="chatbot")
    msg = gr.Textbox(label="Message", placeholder="Type here...", lines=2)
    send_btn = gr.Button("Send")
    stop_btn = gr.Button("Stop")
    max_tokens = gr.Slider(64, 1024, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
    state = gr.State([])  # visual-only chat history
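    # Each history entry is a (user_message, bot_response) tuple, the format gr.Chatbot renders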
    def update_chat(message, chat_history):
        chat_history = chat_history + [(message, None)]  # append only the question
        return "", chat_history
    def generate_full(chat_history, max_tokens, temperature, top_p):
        message = chat_history[-1][0]       # last message sent
        visual_history = chat_history[:-1]  # temporarily drop the pending entry
        full_response = ""
        for chunk in generate_response(message, max_tokens, temperature, top_p):
            full_response = chunk
            # yield the same list twice: once for the chatbot display, once for the state
            yield visual_history + [(message, full_response)], visual_history + [(message, full_response)]
    # Chain: clear the textbox and append the question, then stream the answer into the chat
    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])
    stop_btn.click(stop_stream, inputs=[], outputs=[])
app.launch(share=True)
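# share=True also exposes a temporary public *.gradio.live URL in addition to the local server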