Spestly committed
Commit da8de8d · verified · 1 Parent(s): 739239b

Update app.py

Files changed (1)
  1. app.py +98 -12
app.py CHANGED
@@ -4,42 +4,128 @@ import torch
 import time
 import spaces
 
+# Model configurations
 MODELS = {
     "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
     "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
-    # ... other models ...
+    "Athena-R3 7B": "Spestly/Athena-R3-7B",
+    "Athena-3 3B": "Spestly/Athena-3-3B",
+    "Athena-3 7B": "Spestly/Athena-3-7B",
+    "Athena-3 14B": "Spestly/Athena-3-14B",
+    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
+    "Athena-1 3B": "Spestly/Athena-1-3B",
+    "Athena-1 7B": "Spestly/Athena-1-7B"
 }
 
 @spaces.GPU
 def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
-    # [Same as your function]
-    # ... code omitted for brevity ...
-    pass  # Use your original code here
+    """Generate response using ZeroGPU - all CUDA operations happen here"""
+    print(f"🚀 Loading {model_id}...")
+    start_time = time.time()
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        trust_remote_code=True
+    )
+    load_time = time.time() - start_time
+    print(f"✅ Model loaded in {load_time:.2f}s")
+
+    # Build messages in proper chat format (OpenAI-style messages)
+    messages = []
+    system_prompt = (
+        "You are Athena, a helpful, harmless, and honest AI assistant. "
+        "You provide clear, accurate, and concise responses to user questions. "
+        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
+        "You are finetuned by Aayan Mishra"
+    )
+    messages.append({"role": "system", "content": system_prompt})
+
+    # Add conversation history (OpenAI-style)
+    for msg in conversation:
+        if msg["role"] in ("user", "assistant"):
+            messages.append({"role": msg["role"], "content": msg["content"]})
+
+    # Add current user message
+    messages.append({"role": "user", "content": user_message})
+
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    inputs = tokenizer(prompt, return_tensors="pt")
+    device = next(model.parameters()).device
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    generation_start = time.time()
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_length,
+            temperature=temperature,
+            do_sample=True,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id
+        )
+    generation_time = time.time() - generation_start
+    response = tokenizer.decode(
+        outputs[0][inputs['input_ids'].shape[-1]:],
+        skip_special_tokens=True
+    ).strip()
+    return response, load_time, generation_time
 
 def respond(history, message, model_name, max_length, temperature):
+    """Main function for custom Chatbot interface"""
     if not message.strip():
-        return history + [["user", message], ["assistant", "Please enter a message"]], ""
+        history = history + [["user", message], ["assistant", "Please enter a message"]]
+        return history, ""
     model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
     try:
-        response, _, _ = generate_response(model_id, history, message, max_length, temperature)
+        # Format history for Athena
+        formatted_history = []
+        for i in range(0, len(history), 2):
+            if i < len(history):
+                user_msg = history[i][1] if history[i][0] == "user" else ""
+                assistant_msg = history[i+1][1] if i+1 < len(history) and history[i+1][0] == "assistant" else ""
+                if user_msg:
+                    formatted_history.append({"role": "user", "content": user_msg})
+                if assistant_msg:
+                    formatted_history.append({"role": "assistant", "content": assistant_msg})
+        response, load_time, generation_time = generate_response(
+            model_id, formatted_history, message, max_length, temperature
+        )
         history = history + [["user", message], ["assistant", response]]
         return history, ""
     except Exception as e:
         history = history + [["user", message], ["assistant", f"Error: {str(e)}"]]
         return history, ""
 
-with gr.Blocks() as demo:
+css = """
+.message {
+    padding: 10px;
+    margin: 5px;
+    border-radius: 10px;
+}
+"""
+
+theme = gr.themes.Monochrome()
+
+with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
     gr.Markdown("# 🚀 Athena Playground Chat")
     gr.Markdown("*Powered by HuggingFace ZeroGPU*")
 
-    chatbot = gr.Chatbot(height=500)
+    chatbot = gr.Chatbot(height=500, label="Athena", avatar="🤖")
     state = gr.State([])  # chat history
 
     with gr.Row():
-        user_input = gr.Textbox(label="Your message", scale=8)
+        user_input = gr.Textbox(label="Your message", scale=8, autofocus=True)
         send_btn = gr.Button(value="Send", scale=1)
 
-    # Place settings at the bottom!
+    # --- Configuration controls at the bottom ---
     gr.Markdown("### ⚙️ Model & Generation Settings")
     with gr.Row():
         model_choice = gr.Dropdown(
@@ -59,11 +145,11 @@ with gr.Blocks() as demo:
             info="Higher values = more creative responses"
         )
 
-    def custom_chat(history, message, model_name, max_length, temperature):
+    def chat_submit(history, message, model_name, max_length, temperature):
         return respond(history, message, model_name, max_length, temperature)
 
     send_btn.click(
-        chat_submit,
+        chat_submit,
         inputs=[state, user_input, model_choice, max_length, temperature],
         outputs=[chatbot, user_input]
    )
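
For readers tracing the new history handling: the UI keeps chat history as [role, text] pairs, while generate_response() expects OpenAI-style message dicts. Below is a minimal standalone sketch (not part of the commit, example messages are hypothetical, and slightly simplified) of the conversion the updated respond() performs before calling the model:

# Sketch only: convert [role, text] pairs into OpenAI-style
# {"role": ..., "content": ...} dicts, mirroring respond() above.
history = [
    ["user", "Hi!"],
    ["assistant", "Hello! How can I help?"],
]

formatted_history = []
for i in range(0, len(history), 2):
    user_msg = history[i][1] if history[i][0] == "user" else ""
    assistant_msg = history[i + 1][1] if i + 1 < len(history) and history[i + 1][0] == "assistant" else ""
    if user_msg:
        formatted_history.append({"role": "user", "content": user_msg})
    if assistant_msg:
        formatted_history.append({"role": "assistant", "content": assistant_msg})

print(formatted_history)
# [{'role': 'user', 'content': 'Hi!'}, {'role': 'assistant', 'content': 'Hello! How can I help?'}]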