Didier committed on
Commit 81fa74f · verified · 1 Parent(s): 42b612a

Create llm.py

Files changed (1)
  1. llm.py +133 -0
llm.py ADDED
@@ -0,0 +1,133 @@
"""
File: llm.py
Description: Large language model utility functions.
Author: Didier Guillevic
Date: 2025-05-03
"""

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
from collections.abc import Iterator
import threading
import torch

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

#
# Load the model: "Qwen/Qwen3-4B"
#
model_id = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto"
)
model = torch.compile(model)
model.eval()  # inference mode

# Get the end-of-thinking token id (used to split the response)
end_think_token_id = tokenizer.convert_tokens_to_ids("</think>")

# Output information about the model
def model_info(model):
    # Number of parameters
    total_params = sum(p.numel() for p in model.parameters())

    # Estimated memory usage of the parameters, in bytes (converted to GB below)
    # (torch.dtype.itemsize requires PyTorch >= 2.1)
    param_size = total_params * model.dtype.itemsize

    return {
        "dtype": model.dtype,
        "device": model.device,
        "nb_parameters": f"{total_params / 1e6:.2f} M",
        "size": f"{param_size / 1024**3:.2f} GB"
    }

logger.info(f"{model_info(model)}")

#
# Build (text) messages
#
def build_messages(
    message: str,
    history: list[dict]
) -> list[dict]:
    """Build messages given the message & history from a **text** chat interface.

    Args:
        message: user input
        history: list of dictionaries (with user & assistant messages)

    Returns:
        list of messages (to be sent to the model)
    """
    messages = list(history)  # copy so the caller's history is not mutated
    # Note: thinking is enabled by default, so ' /think' can be omitted;
    # appending ' /no_think' to the message would disable it.
    messages.append({
        'role': 'user',
        'content': message
        #'content': message + (' /think' if thinking else ' /no_think')
    })

    return messages


#
# Stream response
#
@torch.inference_mode()
def stream_response(
    messages: list[dict],
    enable_thinking: bool = True,
    max_new_tokens: int = 1_024
) -> Iterator[str]:
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages (to be sent to the model)
        enable_thinking: whether the model should think before responding
        max_new_tokens: maximum number of tokens to generate

    Yields:
        the response text accumulated so far (when thinking is enabled, this
        includes the thinking content produced before the final answer)
    """
    # Apply the chat template and build the model inputs
    model_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking
    )
    model_inputs = tokenizer(
        [model_prompt],
        return_tensors="pt"
    ).to(model.device)

    # Generate in a background thread and stream the decoded text
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        top_k=20,
        repetition_penalty=1.5,
        min_p=0.0,
        use_cache=True,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for text in streamer:
        response += text
        yield response
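
For context, here is a minimal, hypothetical sketch of how these helpers might be consumed from a Gradio chat app (e.g. an app.py in the same Space). The Gradio wiring is not part of this commit; the respond callback and its details are illustrative assumptions only.

# Hypothetical usage sketch (not part of this commit): wiring build_messages()
# and stream_response() into a Gradio ChatInterface. Names are assumptions.
import gradio as gr
import llm

def respond(message, history):
    # With type="messages", history is a list of {'role', 'content'} dicts,
    # matching what build_messages() expects.
    messages = llm.build_messages(message, history)
    # stream_response() yields the accumulated response text, which Gradio
    # renders as a progressively updating assistant message.
    yield from llm.stream_response(messages, enable_thinking=True)

demo = gr.ChatInterface(fn=respond, type="messages")

if __name__ == "__main__":
    demo.launch()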