Spestly committed · Commit 3a04e30 · verified · 1 Parent(s): 77246c4

Update app.py

Files changed (1)
  1. app.py +141 -78
app.py CHANGED
@@ -2,67 +2,34 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import time
+import spaces
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-MODELS = {
-    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
-    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
-    "Athena-R3 7B": "Spestly/Athena-R3-7B",
-    "Athena-3 3B": "Spestly/Athena-3-3B",
-    "Athena-3 7B": "Spestly/Athena-3-7B",
-    "Athena-3 14B": "Spestly/Athena-3-14B",
-    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
-    "Athena-1 3B": "Spestly/Athena-1-3B",
-    "Athena-1 7B": "Spestly/Athena-1-7B"
-}
-
-loaded_models = {}
-loaded_tokenizers = {}
-
-def load_model(model_name):
-    if model_name in loaded_models:
-        return loaded_models[model_name], loaded_tokenizers[model_name]
-
-    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
-    print(f"🚀 Loading {model_id} on {device}...")
+# ZeroGPU decorator for GPU-intensive functions
+@spaces.GPU
+def load_model_gpu(model_id):
+    """Load model on ZeroGPU"""
+    print(f"🚀 Loading {model_id} on ZeroGPU...")
     start_time = time.time()
-
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        torch_dtype=torch.bfloat16,
-        device_map=None
+        torch_dtype=torch.float16,  # Use float16 for better memory efficiency
+        device_map="auto",
+        trust_remote_code=True
     )
-    model.to(device)
-    model.eval()
-
+
     load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f}s, GPU mem: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
-
-    loaded_models[model_name] = model
-    loaded_tokenizers[model_name] = tokenizer
+    print(f"✅ Model loaded in {load_time:.2f}s")
+
     return model, tokenizer
 
-def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
-    if conversation is None:
-        conversation = []
-    model, tokenizer = load_model(model_name)
-
-    # Append user message to conversation
-    conversation.append(("User", user_message))
-
-    # Build prompt from conversation history (simple concatenation)
-    prompt = ""
-    for speaker, text in conversation:
-        if speaker == "User":
-            prompt += f"User: {text}\n"
-        else:
-            prompt += f"Athena: {text}\n"
-    prompt += "Athena:"
-
+@spaces.GPU
+def generate_response(model, tokenizer, prompt, max_length=512, temperature=0.7):
+    """Generate response using ZeroGPU"""
+    device = next(model.parameters()).device
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
-
+
     start_time = time.time()
     with torch.no_grad():
         outputs = model.generate(
@@ -71,56 +38,152 @@ def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
             temperature=temperature,
             do_sample=True,
             top_p=0.9,
-            pad_token_id=tokenizer.eos_token_id
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id
         )
+
     generation_time = time.time() - start_time
-
-    output_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True).strip()
-
-    conversation.append(("Athena", output_text))
-
-    stats = f"⚡ Generated in {generation_time:.2f}s | GPU mem: {torch.cuda.memory_allocated()/1e9:.2f} GB | Temp: {temperature}"
-
-    return conversation, "", stats
+    output_text = tokenizer.decode(
+        outputs[0][inputs['input_ids'].shape[-1]:],
+        skip_special_tokens=True
+    ).strip()
+
+    return output_text, generation_time
+
+# Model configurations
+MODELS = {
+    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
+    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
+    "Athena-R3 7B": "Spestly/Athena-R3-7B",
+    "Athena-3 3B": "Spestly/Athena-3-3B",
+    "Athena-3 7B": "Spestly/Athena-3-7B",
+    "Athena-3 14B": "Spestly/Athena-3-14B",
+    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
+    "Athena-1 3B": "Spestly/Athena-1-3B",
+    "Athena-1 7B": "Spestly/Athena-1-7B"
+}
 
-with gr.Blocks(title="Athena Playground Chat") as demo:
+def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
+    if not user_message.strip():
+        return conversation, "", "Please enter a message"
+
+    if conversation is None:
+        conversation = []
+
+    # Get model ID
+    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
+
+    try:
+        # Load model and tokenizer using ZeroGPU
+        model, tokenizer = load_model_gpu(model_id)
+
+        # Append user message to conversation
+        conversation.append([user_message, ""])
+
+        # Build prompt from conversation history
+        prompt = ""
+        for user_msg, assistant_msg in conversation[:-1]:  # Exclude the current message
+            prompt += f"User: {user_msg}\nAthena: {assistant_msg}\n"
+        prompt += f"User: {user_message}\nAthena:"
+
+        # Generate response using ZeroGPU
+        output_text, generation_time = generate_response(
+            model, tokenizer, prompt, max_length, temperature
+        )
+
+        # Update the last conversation entry with the response
+        conversation[-1][1] = output_text
+
+        stats = f"⚡ Generated in {generation_time:.2f}s | Model: {model_name} | Temp: {temperature}"
+
+        return conversation, "", stats
+
+    except Exception as e:
+        error_msg = f"Error: {str(e)}"
+        if conversation:
+            conversation[-1][1] = error_msg
+        else:
+            conversation = [[user_message, error_msg]]
+        return conversation, "", f"❌ Error occurred: {str(e)}"
+
+def clear_chat():
+    return [], "", ""
+
+# CSS for better styling
+css = """
+#chatbot {
+    height: 600px;
+}
+.message {
+    padding: 10px;
+    margin: 5px;
+    border-radius: 10px;
+}
+"""
+
+# Create Gradio interface
+with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
     gr.Markdown("# 🚀 Athena Playground Chat")
-
+    gr.Markdown("*Powered by HuggingFace ZeroGPU*")
+
     with gr.Row():
         with gr.Column(scale=1):
             model_choice = gr.Dropdown(
-                label="Model",
+                label="📱 Model",
                 choices=list(MODELS.keys()),
-                value="Athena-R3X 8B"
+                value="Athena-R3X 8B",
+                info="Select which Athena model to use"
            )
-            max_length = gr.Slider(32, 4096, value=512, label="Max Tokens")
-            temperature = gr.Slider(0.1, 2.0, value=0.7, label="Creativity")
-            clear_btn = gr.Button("Clear Chat")
-
+            max_length = gr.Slider(
+                32, 2048, value=512,
+                label="📏 Max Tokens",
+                info="Maximum number of tokens to generate"
+            )
+            temperature = gr.Slider(
+                0.1, 2.0, value=0.7,
+                label="🎨 Creativity",
+                info="Higher values = more creative responses"
+            )
+            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
+
         with gr.Column(scale=3):
-            chat_history = gr.Chatbot(elem_id="chatbot").style(height=600)
+            chat_history = gr.Chatbot(
+                elem_id="chatbot",
+                show_label=False,
+                avatar_images=["👤", "🤖"]
+            )
            user_input = gr.Textbox(
                 placeholder="Ask Athena anything...",
                 label="Your message",
-                lines=2
+                lines=2,
+                max_lines=10
            )
-            submit_btn = gr.Button("Send")
-
-    def clear_chat():
-        return [], "", ""
-
+            with gr.Row():
+                submit_btn = gr.Button("📤 Send", variant="primary")
+                stats_output = gr.Textbox(
+                    label="Stats",
+                    interactive=False,
+                    show_label=False,
+                    placeholder="Stats will appear here..."
+                )
+
+    # Event handlers
     submit_btn.click(
         chatbot,
         inputs=[chat_history, user_input, model_choice, max_length, temperature],
-        outputs=[chat_history, user_input, gr.Textbox(label="Stats")],
-        queue=True
+        outputs=[chat_history, user_input, stats_output]
    )
-
+
+    user_input.submit(
+        chatbot,
+        inputs=[chat_history, user_input, model_choice, max_length, temperature],
+        outputs=[chat_history, user_input, stats_output]
+    )
+
     clear_btn.click(
         clear_chat,
         inputs=[],
-        outputs=[chat_history, user_input, gr.Textbox(label="Stats")]
+        outputs=[chat_history, user_input, stats_output]
    )
 
-if __name__ == "__main__":
-    demo.launch()
+if __name__ == "__main__":
+    demo.launch()
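
Review note: besides the move to ZeroGPU, the behavioral core of this commit is the switch of the chat history from `(speaker, text)` tuples to Gradio-style `[user_msg, assistant_msg]` pairs, with the prompt rebuilt from completed turns only. Below is a minimal sketch of that logic, runnable without downloading a model; the `build_prompt` helper is hypothetical (app.py inlines this loop inside `chatbot`):

```python
# Hypothetical helper mirroring the prompt-building loop in the new chatbot();
# app.py inlines this logic rather than defining a function for it.
def build_prompt(conversation, user_message):
    # History is a list of [user_msg, assistant_msg] pairs. The current turn
    # is appended with an empty assistant slot, so conversation[:-1] walks
    # only the completed turns.
    conversation = conversation + [[user_message, ""]]
    prompt = ""
    for user_msg, assistant_msg in conversation[:-1]:
        prompt += f"User: {user_msg}\nAthena: {assistant_msg}\n"
    prompt += f"User: {user_message}\nAthena:"
    return prompt

history = [["Hi!", "Hello! I'm Athena."]]
print(build_prompt(history, "What can you do?"))
# User: Hi!
# Athena: Hello! I'm Athena.
# User: What can you do?
# Athena:
```

Also worth flagging for review: `@spaces.GPU` only schedules GPU time when the app runs inside a Space, and since the old `loaded_models` cache was dropped, `load_model_gpu` is now re-entered on every chat turn.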