Reality123b committed
Commit 83e20b0 · verified · 1 Parent(s): 4e0506f

Update app.py

Files changed (1): app.py (+41 -19)
app.py CHANGED
@@ -1,15 +1,31 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+import time
 
 # Initialize model and tokenizer
 model_name = "Qwen/Qwen2.5-3B-Instruct"
+print("Loading model and tokenizer...")
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype="auto",
     device_map="auto"
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+print("Model and tokenizer loaded!")
+
+def simulate_typing(text, min_chars_per_sec=20, max_chars_per_sec=60):
+    """Simulate typing animation with variable speed."""
+    full_text = ""
+    words = text.split()
+    for i, word in enumerate(words):
+        full_text += word
+        if i < len(words) - 1:
+            full_text += " "
+        # Vary typing speed between min and max chars per second
+        delay = 1 / (min_chars_per_sec + (max_chars_per_sec - min_chars_per_sec) * torch.rand(1).item())
+        time.sleep(delay)
+        yield full_text
 
 def generate_response(
     message,
@@ -36,32 +52,37 @@ def generate_response(
         add_generation_prompt=True
     )
 
-    # Prepare model inputs
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-    # Generate response
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        do_sample=True
-    )
-
-    # Extract generated text
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-    yield response
+    # Prepare model inputs and generate in one go
+    with torch.inference_mode():
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+        generated_ids = model.generate(
+            **model_inputs,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        generated_ids = generated_ids[0, len(model_inputs.input_ids[0]):]
+        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
-# Custom CSS for the Gradio interface
+    # Return response with typing animation
+    for partial_response in simulate_typing(response):
+        yield partial_response
+
+# Custom CSS with typing cursor animation
 custom_css = """
 @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
 body, .gradio-container {
     font-family: 'Inter', sans-serif;
 }
+.typing-cursor::after {
+    content: '|';
+    animation: blink 1s step-start infinite;
+}
+@keyframes blink {
+    50% { opacity: 0; }
+}
 """
 
 # System message
@@ -102,4 +123,5 @@ demo = gr.ChatInterface(
 
 # Launch the demo
 if __name__ == "__main__":
+    demo.queue()  # Enable queuing for better handling of multiple requests
     demo.launch()
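A note on the new decode step (not part of the commit): with a decoder-only model, model.generate returns the prompt tokens followed by the newly generated ones, so the slice generated_ids[0, len(model_inputs.input_ids[0]):] strips the echoed prompt before decoding. A minimal sketch with made-up token ids (hypothetical values, no model needed):

import torch

# Pretend tokenizer output and generate() result; the ids are invented.
input_ids = torch.tensor([[101, 2054, 2003, 102]])                 # prompt only
generated_ids = torch.tensor([[101, 2054, 2003, 102, 7592, 999]])  # prompt + reply

# Same slice the commit uses: drop the first len(prompt) positions.
reply_ids = generated_ids[0, len(input_ids[0]):]
print(reply_ids.tolist())  # -> [7592, 999], only the newly generated tokens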
 
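How the pieces fit together (a minimal sketch, not part of the commit): gr.ChatInterface accepts a generator function as its handler and re-renders the assistant message on every yield, so the growing strings produced by simulate_typing are what create the typewriter effect, and demo.queue() lets simultaneous visitors wait in line for the single loaded model instead of colliding. The sketch below uses a hypothetical stand-in, fake_generate, in place of the real generate_response:

import time
import gradio as gr

def fake_generate(message, history):
    # Stand-in for generate_response: yield successively longer strings;
    # ChatInterface re-renders the bot message after each yield.
    response = "Streaming happens one yielded string at a time."
    shown = ""
    for word in response.split():
        shown = (shown + " " + word).strip()
        time.sleep(0.05)  # fixed delay; the committed code varies it per word
        yield shown

demo = gr.ChatInterface(fn=fake_generate)

if __name__ == "__main__":
    demo.queue()  # same call the commit adds before launch
    demo.launch()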