import os
import threading
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

# Hugging Face access token, read from the "Key" environment variable
hf_token = os.getenv("Key")

# Model configuration
model_id = "lxcorp/Synap-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
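# Load the model in half precision when a GPU is available to reduce memory use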
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=hf_token
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Custom CSS for the UI
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');
* {
    font-family: 'JetBrains Mono', monospace !important;
}
html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}
textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""

# Global stop flag shared by the UI callbacks
stop_signal = False

def stop_stream():
    global stop_signal
    stop_signal = True
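    # Note: this only stops streaming tokens to the UI; the background
    # generation thread keeps running until model.generate() finishes.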

# Streaming generation
def generate_response(message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False

    prompt = f"Contexto: {message}\nResposta curta e direta:"
    
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Stream decoded tokens as they are produced, skipping the prompt echo and special tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id
    )

    # Run generation in a background thread so this generator can consume the streamer
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
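    # The thread is not joined later; it is left to finish on its own.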

    full_text = ""
    for token in streamer:
        if stop_signal:
            break
        full_text += token
        yield full_text.strip()

    if stop_signal:
        return

# Gradio interface
with gr.Blocks(css=css, theme="NoCrypt/miku") as app:
    chatbot = gr.Chatbot(label="Synap - 2B", elem_id="chatbot")
    msg = gr.Textbox(label="Mensagem", placeholder="Digite aqui...", lines=2)
    send_btn = gr.Button("Enviar")
    stop_btn = gr.Button("Parar")

    max_tokens = gr.Slider(64, 1024, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    state = gr.State([])  # chat history, used for display only

    def update_chat(message, chat_history):
        chat_history = chat_history + [(message, None)]  # append the user message with no reply yet
        return "", chat_history

    def generate_full(chat_history, max_tokens, temperature, top_p):
        message = chat_history[-1][0]  # most recently sent message
        visual_history = chat_history[:-1]  # temporarily drop the pending (unanswered) entry

        full_response = ""
        for chunk in generate_response(message, max_tokens, temperature, top_p):
            full_response = chunk
            yield visual_history + [(message, full_response)], visual_history + [(message, full_response)]

    # Two-step chain: record the message and clear the textbox, then stream the reply into the chat
    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])

    stop_btn.click(stop_stream, inputs=[], outputs=[])

app.launch(share=True)