Spestly committed
Commit 739239b · verified · Parent: 90de0bc

Update app.py

Files changed (1)
  1. app.py +43 -128
app.py CHANGED
@@ -4,154 +4,69 @@ import torch
 import time
 import spaces
 
-# Model configurations
 MODELS = {
     "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
     "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
-    "Athena-R3 7B": "Spestly/Athena-R3-7B",
-    "Athena-3 3B": "Spestly/Athena-3-3B",
-    "Athena-3 7B": "Spestly/Athena-3-7B",
-    "Athena-3 14B": "Spestly/Athena-3-14B",
-    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
-    "Athena-1 3B": "Spestly/Athena-1-3B",
-    "Athena-1 7B": "Spestly/Athena-1-7B"
+    # ... other models ...
 }
 
 @spaces.GPU
 def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
-    """Generate response using ZeroGPU - all CUDA operations happen here"""
-    print(f"🚀 Loading {model_id}...")
-    start_time = time.time()
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True
-    )
-    load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f}s")
-
-    # Build messages in proper chat format (OpenAI-style messages)
-    messages = []
-    system_prompt = (
-        "You are Athena, a helpful, harmless, and honest AI assistant. "
-        "You provide clear, accurate, and concise responses to user questions. "
-        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
-        "You are finetuned by Aayan Mishra"
-    )
-    messages.append({"role": "system", "content": system_prompt})
-
-    # Add conversation history (OpenAI-style)
-    for msg in conversation:
-        if msg["role"] in ("user", "assistant"):
-            messages.append({"role": msg["role"], "content": msg["content"]})
-
-    # Add current user message
-    messages.append({"role": "user", "content": user_message})
-
-    prompt = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    inputs = tokenizer(prompt, return_tensors="pt")
-    device = next(model.parameters()).device
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-    generation_start = time.time()
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=max_length,
-            temperature=temperature,
-            do_sample=True,
-            top_p=0.9,
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id
-        )
-    generation_time = time.time() - generation_start
-    response = tokenizer.decode(
-        outputs[0][inputs['input_ids'].shape[-1]:],
-        skip_special_tokens=True
-    ).strip()
-    return response, load_time, generation_time
+    # [Same as your function]
+    # ... code omitted for brevity ...
+    pass  # Use your original code here
 
-def respond(message, history, model_name, max_length, temperature):
-    """Main function for ChatInterface - simplified signature"""
+def respond(history, message, model_name, max_length, temperature):
     if not message.strip():
-        return "Please enter a message"
+        return history + [["user", message], ["assistant", "Please enter a message"]], ""
     model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
     try:
-        response, load_time, generation_time = generate_response(
-            model_id, history, message, max_length, temperature
-        )
-        return response
+        response, _, _ = generate_response(model_id, history, message, max_length, temperature)
+        history = history + [["user", message], ["assistant", response]]
+        return history, ""
     except Exception as e:
-        return f"Error: {str(e)}"
-
-css = """
-.message {
-    padding: 10px;
-    margin: 5px;
-    border-radius: 10px;
-}
-"""
-
-theme = gr.themes.Monochrome()
+        history = history + [["user", message], ["assistant", f"Error: {str(e)}"]]
+        return history, ""
 
-with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
+with gr.Blocks() as demo:
     gr.Markdown("# 🚀 Athena Playground Chat")
     gr.Markdown("*Powered by HuggingFace ZeroGPU*")
 
-    # --- Create config controls first ---
-    model_choice = gr.Dropdown(
-        label="📱 Model",
-        choices=list(MODELS.keys()),
-        value="Athena-R3X 4B",
-        info="Select which Athena model to use"
-    )
-    max_length = gr.Slider(
-        32, 2048, value=512,
-        label="📝 Max Tokens",
-        info="Maximum number of tokens to generate"
-    )
-    temperature = gr.Slider(
-        0.1, 2.0, value=0.7,
-        label="🎨 Creativity",
-        info="Higher values = more creative responses"
-    )
+    chatbot = gr.Chatbot(height=500)
+    state = gr.State([])  # chat history
 
-    # --- Main chat interface ---
-    chat_interface = gr.ChatInterface(
-        fn=respond,
-        additional_inputs=[model_choice, max_length, temperature],
-        title="Chat with Athena",
-        description="Ask Athena anything!",
-        theme="soft",
-        examples=[
-            ["Hello! How are you?", "Athena-R3X 8B", 512, 0.7],
-            ["What can you help me with?", "Athena-R3X 8B", 512, 0.7],
-            ["Tell me about artificial intelligence", "Athena-R3X 8B", 512, 0.7],
-            ["Write a short poem about space", "Athena-R3X 8B", 512, 0.7]
-        ],
-        cache_examples=False,
-        chatbot=gr.Chatbot(
-            height=500,
-            placeholder="Start chatting with Athena...",
-            show_share_button=False,
-            type="messages"
-        ),
-        type="messages"
-    )
+    with gr.Row():
+        user_input = gr.Textbox(label="Your message", scale=8)
+        send_btn = gr.Button(value="Send", scale=1)
 
-    # --- Configuration controls at the bottom ---
+    # Place settings at the bottom!
     gr.Markdown("### ⚙️ Model & Generation Settings")
     with gr.Row():
-        model_choice.render()
-        max_length.render()
-        temperature.render()
+        model_choice = gr.Dropdown(
+            label="📱 Model",
+            choices=list(MODELS.keys()),
+            value="Athena-R3X 4B",
+            info="Select which Athena model to use"
+        )
+        max_length = gr.Slider(
+            32, 2048, value=512,
+            label="📝 Max Tokens",
+            info="Maximum number of tokens to generate"
+        )
+        temperature = gr.Slider(
+            0.1, 2.0, value=0.7,
+            label="🎨 Creativity",
+            info="Higher values = more creative responses"
+        )
+
+    def custom_chat(history, message, model_name, max_length, temperature):
+        return respond(history, message, model_name, max_length, temperature)
+
+    send_btn.click(
+        custom_chat,
+        inputs=[state, user_input, model_choice, max_length, temperature],
+        outputs=[chatbot, user_input]
+    )
 
 if __name__ == "__main__":
    demo.launch()
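
For readers adapting this pattern, here is a minimal, self-contained sketch of the same Blocks wiring. It assumes Gradio's default tuple-style Chatbot, where each history entry is a [user_text, assistant_text] pair, and `echo_respond` is a hypothetical stand-in for the Space's `respond`:

import gradio as gr

def echo_respond(history, message):
    # Hypothetical stand-in: append one [user_text, assistant_text]
    # pair per turn instead of calling a model.
    history = history + [[message, f"Echo: {message}"]]
    # Return values map onto outputs=[chatbot, state, user_input]:
    # updated display, updated state, and "" to clear the textbox.
    return history, history, ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=500)
    state = gr.State([])  # per-session conversation history
    user_input = gr.Textbox(label="Your message")
    send_btn = gr.Button("Send")

    send_btn.click(
        echo_respond,
        inputs=[state, user_input],
        outputs=[chatbot, state, user_input],
    )

if __name__ == "__main__":
    demo.launch()

Note the extra `state` entry in `outputs`: for the history to persist across turns, the updated list has to be written back into `gr.State`, not only rendered in the Chatbot.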