import os
import threading
import gradio as gr
from huggingface_hub import InferenceClient

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
MODEL_ID = os.getenv("MODEL_ID", "tianzhechu/BookQA-7B-Instruct")
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "Qwen/Qwen2.5-0.5B-Instruct")  # Tokenizer repo for the local backend; set TOKENIZER_ID="" to disable it
# A non-empty TOKENIZER_ID (including the default above) selects the local transformers
# backend; otherwise it can be forced with USE_LOCAL_TRANSFORMERS=1.
USE_LOCAL_TRANSFORMERS = bool(TOKENIZER_ID) or os.getenv("USE_LOCAL_TRANSFORMERS") == "1"

# Remote inference client (only used when the local transformers backend is inactive)
client = None if USE_LOCAL_TRANSFORMERS else InferenceClient(MODEL_ID)

# Lazy-loaded local model/tokenizer when TOKENIZER_ID is provided
local_model = None
local_tokenizer = None


def _ensure_local_model_loaded():
    global local_model, local_tokenizer
    if local_model is not None and local_tokenizer is not None:
        return
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if not TOKENIZER_ID:
        raise RuntimeError(
            "Local transformers backend requires TOKENIZER_ID to be set to a tokenizer repo."
        )
    local_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    local_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
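    # Note: by default the model loads on CPU in full precision. On a machine with a
    # GPU one could pass extra from_pretrained() options (for example torch_dtype, or
    # device_map="auto" with `accelerate` installed); that is an assumption about the
    # runtime and deliberately not hard-coded here.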


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Build the chat history into a messages list and stream the reply incrementally."""
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    if not USE_LOCAL_TRANSFORMERS:
        # Stream chunks from the hosted Inference API; `chunk` avoids shadowing `message`
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:
                response += token
            yield response
        return

    # Local generation using transformers with an alternate tokenizer
    _ensure_local_model_loaded()

    try:
        from transformers import TextIteratorStreamer
    except Exception as e:
        raise RuntimeError(
            "transformers TextIteratorStreamer is required for local streaming; ensure transformers is installed."
        ) from e

    # Use chat template if available; otherwise fall back to a simple concatenation
    try:
        prompt_text = local_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        convo_parts = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            if role == "system":
                convo_parts.append(f"<system>\n{content}\n</system>")
            elif role == "assistant":
                convo_parts.append(f"<assistant>\n{content}\n</assistant>")
            else:
                convo_parts.append(f"<user>\n{content}\n</user>")
        prompt_text = "\n".join(convo_parts) + "\n<assistant>\n"

    inputs = local_tokenizer(prompt_text, return_tensors="pt")

    streamer = TextIteratorStreamer(
        local_tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        inputs=inputs.input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_tokens,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )

    # Run generation in a background thread so tokens can be yielded as they stream in
    thread = threading.Thread(target=local_model.generate, kwargs=generate_kwargs)
    thread.start()

    for new_text in streamer:
        if new_text:
            response += new_text
            yield response

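# Minimal sanity check outside Gradio (a sketch; it assumes whichever backend is
# configured above is actually reachable from this environment):
#
#     for partial in respond(
#         "Who wrote Moby-Dick?", [], "You are a friendly Chatbot.", 64, 0.7, 0.95
#     ):
#         print(partial)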

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
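# Example invocations (assumptions about how the file is run, e.g. saved as app.py;
# not requirements of the app itself):
#
#     python app.py                                  # local transformers backend (default TOKENIZER_ID)
#     TOKENIZER_ID="" python app.py                  # remote inference via InferenceClient(MODEL_ID)
#     MODEL_ID=... TOKENIZER_ID=... python app.py    # local backend with other repos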


if __name__ == "__main__":
    demo.launch()