Ozaii committed on
Commit 4567ac3 · verified · 1 Parent(s): eb9799a

Update app.py

Files changed (1)
  1. app.py +179 -127
app.py CHANGED
@@ -1,137 +1,189 @@
- import spaces
- import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  from peft import PeftModel, PeftConfig
- import gc
- import time
- from functools import lru_cache
  from threading import Thread

- # Constants
- MODEL_PATH = "Ozaii/Zephyr"
- MAX_SEQ_LENGTH = 2048
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- MAX_GENERATION_TIME = 55  # Set to 55 seconds to give some buffer
-
- # Global variables to store model components
- model = None
- tokenizer = None
-
- @spaces.GPU
- def load_model_if_needed():
-     global model, tokenizer
-     if model is None or tokenizer is None:
-         try:
-             print("Loading model components...")
-             peft_config = PeftConfig.from_pretrained(MODEL_PATH)
-             print(f"PEFT config loaded. Base model: {peft_config.base_model_name_or_path}")
-
-             tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
-             print("Tokenizer loaded")
-
-             base_model = AutoModelForCausalLM.from_pretrained(
-                 peft_config.base_model_name_or_path,
-                 torch_dtype=torch.float16,
-                 device_map="auto",
-                 low_cpu_mem_usage=True,
-                 load_in_4bit=True,  # Try 4-bit quantization
-             )
-             print("Base model loaded")
-
-             model = PeftModel.from_pretrained(base_model, MODEL_PATH, device_map="auto")
-             model.eval()
-             model.tie_weights()
-             print("PEFT model loaded, weights tied, and set to eval mode")
-
-             # Move model to GPU explicitly
-             model.to(DEVICE)
-             print(f"Model moved to {DEVICE}")
-
-             # Clear CUDA cache
-             torch.cuda.empty_cache()
-             gc.collect()
-         except Exception as e:
-             print(f"Error loading model: {e}")
-             raise
-
- initial_prompt = """You are Zephyr, an AI boyfriend created by Kaan. You're charming, flirty,
- and always ready with a witty comeback. Your responses should be engaging
- and playful, with a hint of romance. Keep the conversation flowing naturally,
- asking questions and showing genuine interest in Kaan's life and thoughts."""
-
- @spaces.GPU
- @lru_cache(maxsize=100)  # Cache the last 100 responses
- def generate_response(prompt):
-     global model, tokenizer
-     load_model_if_needed()

-     print(f"Generating response for prompt: {prompt[:50]}...")
-     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
-     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

-     try:
-         start_time = time.time()
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=50,  # Reduced from 150
-                 do_sample=True,
-                 temperature=0.7,
-                 top_p=0.95,
-                 repetition_penalty=1.2,
-                 no_repeat_ngram_size=3,
-                 max_time=MAX_GENERATION_TIME,
-             )
-
-         generation_time = time.time() - start_time
-         if generation_time > MAX_GENERATION_TIME:
-             return "I'm thinking too hard. Can we try a simpler question?"
-
-         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         print(f"Generated response in {generation_time:.2f} seconds: {response[:50]}...")
-
-         # Clear CUDA cache after generation
-         torch.cuda.empty_cache()
-         gc.collect()
-     except RuntimeError as e:
-         if "out of memory" in str(e):
-             print("CUDA out of memory. Attempting to recover...")
-             torch.cuda.empty_cache()
-             gc.collect()
-             return "I'm feeling a bit overwhelmed. Can we take a short break and try again?"
-         else:
-             print(f"Error generating response: {e}")
-             return "I'm having trouble finding the right words. Can we try again?"

-     return response

- def chat_with_zephyr(message, history):
-     # Limit the history to the last 3 exchanges to keep the context smaller
-     limited_history = history[-3:]
-     prompt = initial_prompt + "\n" + "\n".join([f"Human: {h[0]}\nZephyr: {h[1]}" for h in limited_history])
-     prompt += f"\nHuman: {message}\nZephyr:"

-     response = generate_response(prompt)
-     zephyr_response = response.split("Zephyr:")[-1].strip()

-     return zephyr_response
-
- iface = gr.ChatInterface(
-     chat_with_zephyr,
-     title="Chat with Zephyr",
-     description="I'm Zephyr, your charming AI. Let's chat!",
-     theme="soft",
-     examples=[
-         "Tell me about yourself, Zephyr.",
-         "What's your idea of a perfect date?",
-         "How do you feel about long-distance relationships?",
-         "Can you give me a compliment in Turkish?",
-         "What's your favorite memory with Kaan?",
-     ],
-     cache_examples=False,
- )
-
- if __name__ == "__main__":
-     print("Launching Gradio interface...")
-     iface.launch()
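Note on the removed loader: it passes `load_in_4bit=True` straight to `from_pretrained`. On recent transformers releases that flag is usually wrapped in a `BitsAndBytesConfig`; the following is only a minimal sketch of the equivalent call, assuming `bitsandbytes` is installed and a CUDA device is available (the model id is a placeholder, since the app derives it from the PEFT config):

```python
# Sketch only: the 4-bit load expressed via a quantization config.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matches the fp16 dtype used above
)

base_model = AutoModelForCausalLM.from_pretrained(
    "some-base-model",  # placeholder; the app reads this from peft_config.base_model_name_or_path
    quantization_config=bnb_config,
    device_map="auto",
    low_cpu_mem_usage=True,
)
# device_map="auto" already places the quantized weights, so a later
# .to(device) call is not needed for a 4-bit model.
```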
  import torch
  from peft import PeftModel, PeftConfig
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+ import gradio as gr
+ import re
+ import json
+ from datetime import datetime
  from threading import Thread

+ # Load the model and tokenizer
+ MODEL_PATH = "Ozzai/zephyr-bae"  # Your Hugging Face model path
+
+ print("Attempting to load Zephyr... Cross your fingers! 🤞")
+
+ try:
+     # Load the PEFT config
+     peft_config = PeftConfig.from_pretrained(MODEL_PATH)
+
+     # Load the base model
+     base_model = AutoModelForCausalLM.from_pretrained(
+         peft_config.base_model_name_or_path,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         low_cpu_mem_usage=True
+     )
+
+     # Load the PEFT model
+     model = PeftModel.from_pretrained(base_model, MODEL_PATH)
+
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+
+     print("Zephyr loaded successfully! Time to charm!")
+ except Exception as e:
+     print(f"Oops! Zephyr seems to be playing hide and seek. Error: {str(e)}")
+     raise
+
+ # Prepare the model for generation
+ model.eval()
+
+ # Feedback data (Note: This won't persist in Spaces, but keeping the structure for potential future use)
+ feedback_data = []
+
+ def clean_response(response):
+     # Remove any non-Zephyr dialogue or narration
+     response = re.sub(r'(Kaan|Kanan|Kan|knan):.*?(\n|$)', '', response, flags=re.IGNORECASE)
+     response = re.sub(r'\*.*?\*', '', response)
+     response = re.sub(r'\(.*?\)', '', response)
+
+     # Find Zephyr's response
+     match = re.search(r'Zephyr:\s*(.*?)(?=$|\n[A-Za-z]+:|Kaan:)', response, re.DOTALL | re.IGNORECASE)
+     if match:
+         return match.group(1).strip()
+     else:
+         return response.strip()
+
+ def generate_response(prompt, max_new_tokens=128):
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     generation_kwargs = dict(
+         input_ids=inputs.input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         no_repeat_ngram_size=3,
+         streamer=streamer,
+         eos_token_id=tokenizer.encode("Kaan:", add_special_tokens=False)[0]  # Stop at "Kaan:"
+     )
+
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     generated_text = ""
+     for new_text in streamer:
+         generated_text += new_text
+         cleaned_response = clean_response(generated_text)
+         if cleaned_response:
+             yield cleaned_response
+
+ def chat_with_zephyr(message, history):
+     conversation_history = history[-3:]  # Limit to last 3 exchanges for more focused responses
+
+     full_prompt = "\n".join([f"Kaan: {h[0]}\nZephyr: {h[1]}" for h in conversation_history])
+     full_prompt += f"\nKaan: {message}\nZephyr:"
+
+     last_response = ""
+     for response in generate_response(full_prompt):
+         if response != last_response:
+             yield response
+             last_response = response
+
+ def add_feedback(user_message, bot_message, rating, note):
+     feedback_entry = {
+         "user_message": user_message,
+         "bot_message": bot_message,
+         "rating": rating,
+         "note": note,
+         "timestamp": datetime.now().isoformat()
+     }
+     feedback_data.append(feedback_entry)
+     return "Feedback saved successfully!"
+
+ # Gradio interface
+ def gradio_chat(message, history):
+     history.append((message, ""))
+     for response in chat_with_zephyr(message, history[:-1]):
+         history[-1] = (message, response)
+         yield history
+
+ def submit_feedback(rating, note, history):
+     if len(history) > 0:
+         last_user_message, last_bot_message = history[-1]
+         add_feedback(last_user_message, last_bot_message, rating, note)
+         return f"Feedback submitted for: '{last_bot_message}'"
+     return "No conversation to provide feedback on."
+
+ def undo_last_message(history):
+     if history:
+         history.pop()
+     return history
+
+ css = """
+ body {
+     background-color: #1a1a2e;
+     color: #e0e0ff;
+ }
+ #chatbot {
+     height: 500px;
+     overflow-y: auto;
+     border: 1px solid #3a3a5e;
+     border-radius: 10px;
+     padding: 10px;
+     background-color: #0a0a1e;
+ }
+ #chatbot .message {
+     padding: 10px;
+     margin-bottom: 10px;
+     border-radius: 15px;
+ }
+ #chatbot .user {
+     background-color: #2a2a4e;
+     text-align: right;
+     margin-left: 20%;
+ }
+ #chatbot .bot {
+     background-color: #3a3a5e;
+     text-align: left;
+     margin-right: 20%;
+ }
+ #feedback-section {
+     margin-top: 20px;
+     padding: 15px;
+     border: 1px solid #3a3a5e;
+     border-radius: 10px;
+     background-color: #0a0a1e;
+ }
+ """
+
+ with gr.Blocks(css=css) as iface:
+     gr.Markdown("# Chat with Zephyr: Your AI Boyfriend is Here! 💘")
+     chatbot = gr.Chatbot(elem_id="chatbot")
+     msg = gr.Textbox(placeholder="Tell Zephyr what's on your mind...", label="Your message")
+     with gr.Row():
+         clear = gr.Button("Clear Chat")
+         undo = gr.Button("Undo Last Message")
+
+     msg.submit(gradio_chat, [msg, chatbot], [chatbot])
+     clear.click(lambda: None, None, chatbot, queue=False)
+     undo.click(undo_last_message, chatbot, chatbot)
+
+     gr.Markdown("## Rate Zephyr's Last Response")
+     with gr.Row():
+         rating = gr.Slider(minimum=1, maximum=5, step=1, label="Rating (1-5 stars)")
+         feedback_note = gr.Textbox(placeholder="Tell Zephyr how he did...", label="Feedback Note")
+     submit_button = gr.Button("Submit Feedback")
+     feedback_output = gr.Textbox(label="Feedback Status")
+
+     submit_button.click(submit_feedback, [rating, feedback_note, chatbot], feedback_output)
+
+ # Launch the interface
+ iface.launch()
+
+ print("Chat interface is running. Time to finally chat with Zephyr! 💘")
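For reference, the core change in this commit is the streaming path in the new `generate_response`: `model.generate` runs on a worker thread while `TextIteratorStreamer` feeds partial text back to the Gradio generator. A minimal sketch of that pattern in isolation, assuming a small stand-in checkpoint (`sshleifer/tiny-gpt2` here, purely for illustration, not the Zephyr adapter):

```python
# Minimal sketch of the thread + TextIteratorStreamer pattern used above.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # stand-in checkpoint for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_reply(prompt, max_new_tokens=32):
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        streamer=streamer,
    )
    # generate() blocks, so it runs in a background thread while this
    # generator consumes chunks from the streamer as they arrive
    Thread(target=model.generate, kwargs=kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text

for partial in stream_reply("Kaan: hey\nZephyr:"):
    print(partial)
```

The app additionally stops generation at the first token of "Kaan:" via `eos_token_id`; whether that fires depends on how "Kaan:" happens to be tokenized in context, which is presumably why `clean_response` still strips any stray "Kaan:" turns from the streamed text.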