"""
File: llm.py
Description: Large language model utility functions.
Author: Didier Guillevic
Date: 2025-05-03
"""

from collections.abc import Iterator

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
import threading
import torch
import spaces

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

#
# Load the model: "Qwen/Qwen3-4B"
#
model_id = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto"
)
model = torch.compile(model)
model.eval() # inference mode
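# Note: torch.compile is lazy; compilation happens on the first forward pass,
# so the first call to model.generate() will be noticeably slower than later ones.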

# Get end of thinking response token (used to split the response)
end_think_token_id = tokenizer.convert_tokens_to_ids("</think>")

# Output information about the model
def model_info(model):
    # Number of parameters
    total_params = sum(p.numel() for p in model.parameters())

    # Estimated memory usage of the parameters (in bytes; converted to GB below)
    param_size = total_params * model.dtype.itemsize

    return {
        "dtype": model.dtype,
        "device": model.device,
        "nb_parameters": f"{total_params / 1e6:.2f} M",
        "size": f"{param_size / 1024**3:.2f} GB"
    }

logger.info(f"{model_info(model)}")
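# Illustrative only: for Qwen3-4B in bfloat16 the log above typically reports
# roughly 4,000 M parameters and on the order of 7.5 GB of weights (exact
# figures depend on the checkpoint and the dtype chosen by torch_dtype="auto").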

#
# Build (text) messages
#
def build_messages(
        message: str,
        history: list[dict]
    ) -> list[dict]:
    """Build messages given message & history from a **text** chat interface.

    Args:
        message: user input
        history: list of dictionaries (with user & assistant messages)
    
    Returns:
        list of messages (to be sent to the model)
    """
    # Copy the history so the caller's list is not mutated in place
    messages = list(history)
    # A ' /think' or ' /no_think' suffix could be appended to the message to
    # toggle thinking (thinking is enabled by default for Qwen3), e.g.:
    #   'content': message + (' /think' if thinking else ' /no_think')
    messages.append({
        'role': 'user',
        'content': message
    })

    return messages
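
# Example (illustrative): with history = [
#     {'role': 'user', 'content': 'Hi'},
#     {'role': 'assistant', 'content': 'Hello! How can I help?'},
# ] and message = 'What is 2 + 2?', build_messages returns those two turns
# followed by a new {'role': 'user', 'content': 'What is 2 + 2?'} entry.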


#
# Stream response
#
@spaces.GPU
@torch.inference_mode()
def stream_response(
        messages: list[dict],
        enable_thinking: bool=False,
        max_new_tokens: int=1_024
    ) -> Iterator[tuple[str, str]]:
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages (to be sent to the model)
        enable_thinking: whether the model should think before responding
        max_new_tokens: maximum number of tokens to generate

    Yields:
        (thinking_response, final_response) pairs, growing as tokens stream in
    """
    # apply chat template and get model's inputs
    model_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking
    )
    model_inputs = tokenizer(
        [model_prompt,],
        return_tensors="pt"
    ).to(model.device)

    # get the model's response
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        top_k=20,
        repetition_penalty=1.5,
        min_p=0.0,
        use_cache=True,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    thinking_response = ""
    final_response = ""
    is_final_response = False

    for text in streamer:
        final_response += text
        yield final_response
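

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch only, not part of the Space app itself.
# It assumes the module is run directly on a machine with enough memory to
# load Qwen3-4B and that the `spaces` decorator degrades gracefully outside a
# Hugging Face Space; the prompt and token budget are arbitrary illustrations.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_messages = build_messages("What is the capital of France?", history=[])
    thinking, answer = "", ""
    for thinking, answer in stream_response(
            demo_messages, enable_thinking=True, max_new_tokens=256):
        pass  # consume the stream; keep the last (complete) pair
    print("THINKING:", thinking)
    print("ANSWER:", answer)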