import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# If you have an HF token stored in the Space secrets, uncomment the line below:
# os.environ["HUGGINGFACE_HUB_TOKEN"] = os.getenv("HF_TOKEN", "")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer + model with trust_remote_code, and let Transformers shard/auto-offload if needed.
tokenizer = AutoTokenizer.from_pretrained(
    "Fastweb/FastwebMIIA-7B",
    use_fast=True,
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    "Fastweb/FastwebMIIA-7B",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",            # let HF accelerate/device_map place layers automatically
    trust_remote_code=True
)

model.eval()  # set to eval mode
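# Optional sanity check (assumes accelerate recorded a layer placement because device_map="auto"
# was used): uncomment to see where each layer ended up, e.g. whether anything was offloaded to CPU.
# print(getattr(model, "hf_device_map", None))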

def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Build a list of messages in the format the model expects, apply any chat template,
    tokenize, generate, and decode. Wrap inference in torch.no_grad() to save memory.
    """
    # 1) Build the “chat” message list
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})

    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})

    messages.append({"role": "user", "content": message})

    # 2) Format via the model's chat template
    #    Note: apply_chat_template relies on the chat_template bundled with this tokenizer's config.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
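    # Depending on the Transformers version, the call above either raises or falls back to a
    # default template when no chat_template is bundled; a manual "role: content" prompt
    # could be substituted as a fallback in that case.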
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs.input_ids.to(DEVICE)
    attention_mask = inputs.attention_mask.to(DEVICE)

    # 3) Inference under no_grad
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # 4) Skip the prompt tokens and decode only the newly generated tokens
    generated_tokens = outputs[0][input_ids.shape[1]:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response
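
# Quick local smoke test (a minimal sketch; not used by the Gradio app). Uncomment to run a
# single generation outside the UI with the same defaults the sliders below start from.
# print(respond("Hello, who are you?", history=[],
#               system_message="You are a helpful assistant.",
#               max_tokens=64, temperature=0.7, top_p=0.9))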


# Build a Gradio ChatInterface; a textbox for the system prompt and sliders for the sampling params
chat_interface = gr.ChatInterface(
    fn=respond,
    title="FastwebMIIA‐7B Chatbot",
    description="A simple chat demo using Fastweb/FastwebMIIA‐7B",
    # `additional_inputs` appear in an accordion below the conversation window
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful assistant.",
            label="System message (role: system)"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    # You can tweak CSS or theme here if you like; omitted for brevity.
)
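
# On a shared GPU Space, enabling Gradio's request queue (e.g. chat_interface.queue()) before
# launching is a common way to serialize concurrent generations; omitted here to keep the demo minimal.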

if __name__ == "__main__":
    # On HF Spaces, share=False (the default) is usually what you want; set share=True if you need a public Gradio link.
    chat_interface.launch(server_name="0.0.0.0", server_port=7860)