Spestly committed
Commit eeda09f · verified · Parent: fb28ebe

Update app.py

Files changed (1): app.py +52 -45
app.py CHANGED
@@ -4,17 +4,34 @@ import torch
 import time
 import spaces
 
-# ZeroGPU decorator for GPU-intensive functions
+# Model configurations
+MODELS = {
+    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
+    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
+    "Athena-R3 7B": "Spestly/Athena-R3-7B",
+    "Athena-3 3B": "Spestly/Athena-3-3B",
+    "Athena-3 7B": "Spestly/Athena-3-7B",
+    "Athena-3 14B": "Spestly/Athena-3-14B",
+    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
+    "Athena-1 3B": "Spestly/Athena-1-3B",
+    "Athena-1 7B": "Spestly/Athena-1-7B"
+}
+
 @spaces.GPU
-def load_model_gpu(model_id):
-    """Load model on ZeroGPU"""
-    print(f"🚀 Loading {model_id} on ZeroGPU...")
+def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
+    """Generate response using ZeroGPU - all CUDA operations happen here"""
+
+    # Load model and tokenizer inside the GPU function
+    print(f"🚀 Loading {model_id}...")
     start_time = time.time()
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        torch_dtype=torch.float16,  # Use float16 for better memory efficiency
+        torch_dtype=torch.float16,
         device_map="auto",
         trust_remote_code=True
     )
@@ -22,15 +39,25 @@ def load_model_gpu(model_id):
     load_time = time.time() - start_time
     print(f"✅ Model loaded in {load_time:.2f}s")
 
-    return model, tokenizer
-
-@spaces.GPU
-def generate_response(model, tokenizer, prompt, max_length=512, temperature=0.7):
-    """Generate response using ZeroGPU"""
-    device = next(model.parameters()).device
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    # Build conversation history
+    conversation_history = []
+    for user_msg, assistant_msg in conversation:
+        if user_msg:
+            conversation_history.append(f"User: {user_msg}")
+        if assistant_msg:
+            conversation_history.append(f"Athena: {assistant_msg}")
 
-    start_time = time.time()
+    # Add current user message
+    conversation_history.append(f"User: {user_message}")
+    conversation_history.append("Athena:")
+
+    # Create prompt
+    prompt = "\n".join(conversation_history)
+
+    # Tokenize and generate
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    generation_start = time.time()
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -42,26 +69,15 @@ def generate_response(model, tokenizer, prompt, max_length=512, temperature=0.7)
             eos_token_id=tokenizer.eos_token_id
         )
 
-    generation_time = time.time() - start_time
-    output_text = tokenizer.decode(
+    generation_time = time.time() - generation_start
+
+    # Decode response
+    response = tokenizer.decode(
         outputs[0][inputs['input_ids'].shape[-1]:],
         skip_special_tokens=True
     ).strip()
 
-    return output_text, generation_time
-
-# Model configurations
-MODELS = {
-    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
-    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
-    "Athena-R3 7B": "Spestly/Athena-R3-7B",
-    "Athena-3 3B": "Spestly/Athena-3-3B",
-    "Athena-3 7B": "Spestly/Athena-3-7B",
-    "Athena-3 14B": "Spestly/Athena-3-14B",
-    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
-    "Athena-1 3B": "Spestly/Athena-1-3B",
-    "Athena-1 7B": "Spestly/Athena-1-7B"
-}
+    return response, load_time, generation_time
 
 def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
     if not user_message.strip():
@@ -74,27 +90,18 @@ def chatbot(conversation, user_message, model_name, max_length=512, temperature=
     model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
 
     try:
-        # Load model and tokenizer using ZeroGPU
-        model, tokenizer = load_model_gpu(model_id)
-
-        # Append user message to conversation
+        # Add user message to conversation
        conversation.append([user_message, ""])
 
-        # Build prompt from conversation history
-        prompt = ""
-        for user_msg, assistant_msg in conversation[:-1]:  # Exclude the current message
-            prompt += f"User: {user_msg}\nAthena: {assistant_msg}\n"
-        prompt += f"User: {user_message}\nAthena:"
-
         # Generate response using ZeroGPU
-        output_text, generation_time = generate_response(
-            model, tokenizer, prompt, max_length, temperature
+        response, load_time, generation_time = generate_response(
+            model_id, conversation[:-1], user_message, max_length, temperature
         )
 
-        # Update the last conversation entry with the response
-        conversation[-1][1] = output_text
+        # Update the conversation with the response
+        conversation[-1][1] = response
 
-        stats = f"Generated in {generation_time:.2f}s | Model: {model_name} | Temp: {temperature}"
+        stats = f"Load: {load_time:.1f}s | Gen: {generation_time:.1f}s | Model: {model_name}"
 
         return conversation, "", stats
 
@@ -104,7 +111,7 @@ def chatbot(conversation, user_message, model_name, max_length=512, temperature=
         conversation[-1][1] = error_msg
     else:
         conversation = [[user_message, error_msg]]
-        return conversation, "", f"❌ Error occurred: {str(e)}"
+        return conversation, "", f"❌ Error: {str(e)}"
 
 def clear_chat():
     return [], "", ""
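
The substance of the change: the old code split GPU work across two @spaces.GPU functions, loading the model in load_model_gpu and handing the loaded model to generate_response, while the new code folds loading, prompt building, generation, and decoding into one decorated function, which matches ZeroGPU's model of attaching a GPU only for the duration of a single decorated call. A minimal sketch of that pattern, assuming a standard transformers causal LM (the name run_on_gpu and the 64-token budget are illustrative, not part of the commit):

```python
# Minimal sketch of the single-function ZeroGPU pattern this commit adopts.
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@spaces.GPU  # a GPU is attached only while this call runs
def run_on_gpu(model_id: str, prompt: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,   # halves memory vs. float32
        device_map="auto",           # place weights on the attached GPU
        trust_remote_code=True,
    )
    # accelerate's device hooks usually forward CPU inputs to the GPU,
    # but moving them explicitly avoids relying on that behaviour
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=64)
    # slice off the prompt so only newly generated tokens are decoded
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

The commit also adds a pad-token fallback (tokenizer.pad_token = tokenizer.eos_token), a common guard for tokenizers that ship without a pad token when generation needs padding.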
 
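generate_response also takes over prompt construction from chatbot, which now appends [user_message, ""] and passes conversation[:-1] so the pending pair is excluded from the history loop. For reference, a worked example of the prompt string the new code builds (the sample turns are invented):

```python
# Worked example of the prompt format built inside generate_response;
# the conversation content here is invented for illustration.
conversation = [["Hi!", "Hello! How can I help?"]]
user_message = "What models do you have?"

history = []
for user_msg, assistant_msg in conversation:
    if user_msg:
        history.append(f"User: {user_msg}")
    if assistant_msg:
        history.append(f"Athena: {assistant_msg}")
history.append(f"User: {user_message}")
history.append("Athena:")

print("\n".join(history))
# User: Hi!
# Athena: Hello! How can I help?
# User: What models do you have?
# Athena:
```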
 
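The Gradio wiring itself sits outside the hunks shown above. For orientation, a hedged sketch of how chatbot and clear_chat could be connected, matching their (conversation, textbox, stats) return shape; the component names and layout are assumptions, not the committed UI:

```python
# Hypothetical UI wiring; demo, msg, chat, model_pick, and stats_box are
# assumed names, not taken from the commit.
import gradio as gr

with gr.Blocks() as demo:
    chat = gr.Chatbot()                      # list of [user, assistant] pairs
    msg = gr.Textbox(placeholder="Ask Athena...")
    model_pick = gr.Dropdown(choices=list(MODELS.keys()), value="Athena-R3X 8B")
    stats_box = gr.Markdown()

    # chatbot returns (conversation, cleared_textbox, stats) in that order
    msg.submit(chatbot, [chat, msg, model_pick], [chat, msg, stats_box])
    gr.Button("Clear").click(clear_chat, None, [chat, msg, stats_box])

demo.launch()
```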