Reality123b committed
Commit 3d08dbc · verified · 1 Parent(s): b55e187

Update app.py

Files changed (1)
  app.py +81 -299
app.py CHANGED
@@ -1,269 +1,60 @@
  import gradio as gr
- from pathlib import Path
- import os
- from huggingface_hub import snapshot_download
- from mistral_inference.transformer import Transformer
- from mistral_inference.generate import generate
- from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
- from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, SystemMessage
- from mistral_common.protocol.instruct.request import ChatCompletionRequest

- def download_mistral_model():
-     """Download Mistral model if not already present."""
-     print("Checking for Mistral model...")
-     mistral_models_path = Path.home().joinpath('mistral_models', 'Nemo-Instruct')
-
-     # Check if model files already exist
-     required_files = ["params.json", "consolidated.safetensors", "tekken.json"]
-     files_exist = all(
-         mistral_models_path.joinpath(file).exists()
-         for file in required_files
-     )
-
-     if not files_exist:
-         print("Downloading Mistral model (this may take a while)...")
-         mistral_models_path.mkdir(parents=True, exist_ok=True)
-
-         snapshot_download(
-             repo_id="mistralai/Mistral-Nemo-Instruct-2407",
-             allow_patterns=required_files,
-             local_dir=mistral_models_path
-         )
-         print("Model downloaded successfully!")
-     else:
-         print("Mistral model already downloaded.")
-
-     return mistral_models_path
-
- def setup_mistral():
-     """Initialize Mistral model and tokenizer."""
-     mistral_models_path = download_mistral_model()
-     print("Initializing Mistral model and tokenizer...")
-     tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
-     model = Transformer.from_folder(mistral_models_path)
-     return model, tokenizer
-
- # Global variables for model and tokenizer
- global_model = None
- global_tokenizer = None
-
- def initialize_globals():
-     """Initialize global model and tokenizer if not already done."""
-     global global_model, global_tokenizer
-     if global_model is None or global_tokenizer is None:
-         global_model, global_tokenizer = setup_mistral()
-
- def check_custom_responses(message: str) -> str:
-     """Check for specific patterns and return custom responses."""
-     message_lower = message.lower()
-     custom_responses = {
-         "what is ur name?": "xylaria",
-         "what is ur Name?": "xylaria",
-         "what is Ur name?": "xylaria",
-         "what is Ur Name?": "xylaria",
-         "What is ur name?": "xylaria",
-         "What is ur Name?": "xylaria",
-         "What is Ur name?": "xylaria",
-         "What is Ur Name?": "xylaria",
-         "what's ur name?": "xylaria",
-         "what's ur Name?": "xylaria",
-         "what's Ur name?": "xylaria",
-         "what's Ur Name?": "xylaria",
-         "whats ur name?": "xylaria",
-         "whats ur Name?": "xylaria",
-         "whats Ur name?": "xylaria",
-         "whats Ur Name?": "xylaria",
-         "what's your name?": "xylaria",
-         "what's your Name?": "xylaria",
-         "what's Your name?": "xylaria",
-         "what's Your Name?": "xylaria",
-         "Whats ur name?": "xylaria",
-         "Whats ur Name?": "xylaria",
-         "Whats Ur name?": "xylaria",
-         "Whats Ur Name?": "xylaria",
-         "What Is Your Name?": "xylaria",
-         "What Is Ur Name?": "xylaria",
-         "What Is Your Name?": "xylaria",
-         "What Is Ur Name?": "xylaria",
-         "what is your name?": "xylaria",
-         "what is your Name?": "xylaria",
-         "what is Your name?": "xylaria",
-         "what is Your Name?": "xylaria",
-         "how many 'r' is in strawberry?": "3",
-         "how many 'R' is in strawberry?": "3",
-         "how many 'r' Is in strawberry?": "3",
-         "how many 'R' Is in strawberry?": "3",
-         "How many 'r' is in strawberry?": "3",
-         "How many 'R' is in strawberry?": "3",
-         "How Many 'r' Is In Strawberry?": "3",
-         "How Many 'R' Is In Strawberry?": "3",
-         "how many r is in strawberry?": "3",
-         "how many R is in strawberry?": "3",
-         "how many r Is in strawberry?": "3",
-         "how many R Is in strawberry?": "3",
-         "How many r is in strawberry?": "3",
-         "How many R is in strawberry?": "3",
-         "How Many R Is In Strawberry?": "3",
-         "how many 'r' in strawberry?": "3",
-         "how many r's are in strawberry?": "3",
-         "how many Rs are in strawberry?": "3",
-         "How Many R's Are In Strawberry?": "3",
-         "How Many Rs Are In Strawberry?": "3",
-         "who is your developer?": "sk md saad amin",
-         "who is your Developer?": "sk md saad amin",
-         "who is Your Developer?": "sk md saad amin",
-         "who is ur developer?": "sk md saad amin",
-         "who is ur Developer?": "sk md saad amin",
-         "who is Your Developer?": "sk md saad amin",
-         "Who is ur developer?": "sk md saad amin",
-         "Who is ur Developer?": "sk md saad amin",
-         "who is ur dev?": "sk md saad amin",
-         "Who is ur dev?": "sk md saad amin",
-         "who is your dev?": "sk md saad amin",
-         "Who is your dev?": "sk md saad amin",
-         "Who's your developer?": "sk md saad amin",
-         "Who's ur developer?": "sk md saad amin",
-         "Who Is Your Developer?": "sk md saad amin",
-         "Who Is Ur Developer?": "sk md saad amin",
-         "Who Is Your Dev?": "sk md saad amin",
-         "Who Is Ur Dev?": "sk md saad amin",
-         "who's your developer?": "sk md saad amin",
-         "who's ur developer?": "sk md saad amin",
-         "who is your devloper?": "sk md saad amin",
-         "who is ur devloper?": "sk md saad amin",
-         "how many r is in strawberry?": "3",
-         "how many R is in strawberry?": "3",
-         "how many r Is in strawberry?": "3",
-         "how many R Is in strawberry?": "3",
-         "How many r is in strawberry?": "3",
-         "How many R is in strawberry?": "3",
-         "How Many R Is In Strawberry?": "3",
-         "how many 'r' is in strawberry?": "3",
-         "how many 'R' is in strawberry?": "3",
-         "how many 'r' Is in strawberry?": "3",
-         "how many 'R' Is in strawberry?": "3",
-         "How many 'r' is in strawberry?": "3",
-         "How many 'R' is in strawberry?": "3",
-         "How Many 'r' Is In Strawberry?": "3",
-         "How Many 'R' Is In Strawberry?": "3",
-         "how many r's are in strawberry?": "3",
-         "how many Rs are in strawberry?": "3",
-         "How Many R's Are In Strawberry?": "3",
-         "How Many Rs Are In Strawberry?": "3",
-         "how many Rs's are in strawberry?": "3",
-         "wat is ur name?": "xylaria",
-         "wat is ur Name?": "xylaria",
-         "wut is ur name?": "xylaria",
-         "wut ur name?": "xylaria",
-         "wats ur name?": "xylaria",
-         "wats ur name": "xylaria",
-         "who's ur dev?": "sk md saad amin",
-         "who's your dev?": "sk md saad amin",
-         "who ur dev?": "sk md saad amin",
-         "who's ur devloper?": "sk md saad amin",
-         "how many r in strawbary?": "3",
-         "how many r in strawbary?": "3",
-         "how many R in strawbary?": "3",
-         "how many 'r' in strawbary?": "3",
-         "how many 'R' in strawbary?": "3",
-         "how many r in strawbry?": "3",
-         "how many R in strawbry?": "3",
-         "how many r is in strawbry?": "3",
-         "how many 'r' is in strawbry?": "3",
-         "how many 'R' is in strawbry?": "3",
-         "who is ur dev": "sk md saad amin",
-         "who is ur devloper": "sk md saad amin",
-         "what is ur dev": "sk md saad amin",
-         "who is ur dev?": "sk md saad amin",
-         "who is ur dev?": "sk md saad amin",
-         "whats ur dev?": "sk md saad amin",
-     }
-
-     for pattern, response in custom_responses.items():
-         if pattern in message_lower:
-             return response
-     return None
-
- def is_image_request(message: str) -> bool:
-     """Detect if the message is requesting image generation."""
-     image_triggers = [
-         "generate an image",
-         "create an image",
-         "draw",
-         "make a picture",
-         "generate a picture",
-         "create a picture",
-         "generate art",
-         "create art",
-         "make art",
-         "visualize",
-         "show me",
-     ]
-     message_lower = message.lower()
-     return any(trigger in message_lower for trigger in image_triggers)

- def create_mistral_messages(history, system_message, current_message):
-     """Convert chat history to Mistral message format."""
-     messages = []
-
-     # Add system message if provided
-     if system_message:
-         messages.append(SystemMessage(content=system_message))
-
-     # Add conversation history
      for user_msg, assistant_msg in history:
          if user_msg:
-             messages.append(UserMessage(content=user_msg))
          if assistant_msg:
-             messages.append(AssistantMessage(content=assistant_msg))
-
-     # Add current message
-     messages.append(UserMessage(content=current_message))

-     return messages

- def respond(message, history, system_message, max_tokens=16343, temperature=0.7, top_p=0.95):
-     """Main response function using Mistral model."""
-     # First check for custom responses
-     custom_response = check_custom_responses(message)
-     if custom_response:
-         yield custom_response
-         return

-     # Check for image requests
-     if is_image_request(message):
-         yield "Sorry, image generation is not supported in this implementation."
-         return

-     try:
-         # Initialize global model and tokenizer if needed
-         initialize_globals()
-
-         # Prepare messages for Mistral
-         mistral_messages = create_mistral_messages(history, system_message, message)
-
-         # Create chat completion request
-         completion_request = ChatCompletionRequest(messages=mistral_messages)
-
-         # Encode the request
-         tokens = global_tokenizer.encode_chat_completion(completion_request).tokens
-
-         # Generate response
-         out_tokens, _ = generate(
-             [tokens],
-             global_model,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             top_p=top_p,
-             eos_id=global_tokenizer.instruct_tokenizer.tokenizer.eos_id
-         )
-
-         # Decode and yield response
-         response = global_tokenizer.decode(out_tokens[0])
-         yield response

-     except Exception as e:
-         yield f"An error occurred: {str(e)}"

  # Custom CSS for the Gradio interface
  custom_css = """
@@ -274,50 +65,41 @@ body, .gradio-container {
  """

  # System message
- system_message = """Xylaria (v1.2.9) is an AI assistant developed by Sk Md Saad Amin, designed to provide efficient, practical support in various domains with adaptable communication."""

- def main():
-     print("Starting Mistral Chat Interface...")
-     print("Initializing model (this may take a few minutes on first run)...")
-
-     # Initialize model and tokenizer at startup
-     initialize_globals()
-
-     # Create Gradio interface
-     demo = gr.ChatInterface(
-         respond,
-         additional_inputs=[
-             gr.Textbox(
-                 value=system_message,
-                 visible=False,
-             ),
-             gr.Slider(
-                 minimum=1,
-                 maximum=16343,
-                 value=16343,
-                 step=1,
-                 label="Max new tokens"
-             ),
-             gr.Slider(
-                 minimum=0.1,
-                 maximum=4.0,
-                 value=0.7,
-                 step=0.1,
-                 label="Temperature"
-             ),
-             gr.Slider(
-                 minimum=0.1,
-                 maximum=1.0,
-                 value=0.95,
-                 step=0.05,
-                 label="Top-p (nucleus sampling)"
-             ),
-         ],
-         css=custom_css
-     )
-
-     print("Launch successful! Interface is ready to use.")
-     demo.launch()

  if __name__ == "__main__":
-     main()
 
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch

+ # Initialize model and tokenizer
+ model_name = "Qwen/Qwen2.5-3B-Instruct"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="auto"
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ def generate_response(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     # Prepare conversation history
+     messages = [{"role": "system", "content": system_message}]
      for user_msg, assistant_msg in history:
          if user_msg:
+             messages.append({"role": "user", "content": user_msg})
          if assistant_msg:
+             messages.append({"role": "assistant", "content": assistant_msg})

+     messages.append({"role": "user", "content": message})

+     # Convert messages to model input format
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )

+     # Prepare model inputs
+     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

+     # Generate response
+     generated_ids = model.generate(
+         **model_inputs,
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=True
+     )

+     # Extract generated text
+     generated_ids = [
+         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+     ]
+     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     yield response

  # Custom CSS for the Gradio interface
  custom_css = """
  """

  # System message
+ system_message = """You are Qwen, created by Alibaba Cloud. You are a helpful assistant."""

+ # Gradio chat interface
+ demo = gr.ChatInterface(
+     generate_response,
+     additional_inputs=[
+         gr.Textbox(
+             value=system_message,
+             visible=False,
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=2048,
+             value=512,
+             step=1,
+             label="Max new tokens"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=2.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)"
+         ),
+     ],
+     css=custom_css
+ )

+ # Launch the demo
  if __name__ == "__main__":
+     demo.launch()
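
Note that the committed generate_response builds the entire completion before its single yield, so the chat UI only updates once generation has finished. If token-by-token streaming is wanted, transformers' TextIteratorStreamer fits the same generator signature. The sketch below is illustrative and not part of this commit: it assumes the same global model and tokenizer defined in app.py above, and the name generate_response_streaming is hypothetical.

# Sketch only (not in this commit): a streaming counterpart to
# generate_response, assuming the same global `model` / `tokenizer`.
from threading import Thread

from transformers import TextIteratorStreamer

def generate_response_streaming(message, history, system_message,
                                max_tokens, temperature, top_p):
    # Same prompt assembly as the committed generate_response
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # skip_prompt=True drops the echoed input; chunks arrive already decoded
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # model.generate blocks until done, so run it on a worker thread
    # and drain the streamer from this generator
    Thread(
        target=model.generate,
        kwargs=dict(
            **model_inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            streamer=streamer,
        ),
    ).start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial  # gr.ChatInterface re-renders the growing reply

Passing generate_response_streaming to gr.ChatInterface in place of generate_response would leave the sliders and the hidden system-prompt textbox unchanged.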