WillHeld committed
Commit 77129be · verified · 1 Parent(s): b3a18de

Update app.py

Files changed (1)
  1. app.py +195 -155
app.py CHANGED
@@ -1,180 +1,220 @@
-from __future__ import annotations
-
-import json
 import os
-import time
 import uuid
-from threading import Thread
-from typing import List, Dict
 
-import gradio as gr
 from gradio_modal import Modal
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-)
-from datasets import (
-    Dataset,
-    load_dataset,
-    concatenate_datasets,
-    DownloadMode,
-)
-from huggingface_hub import HfApi, login
-import spaces
 
-# ─────────────────────────── model & constants ────────────────────────────
-checkpoint = "marin-community/marin-8b-instruct"
 device = "cuda"
-
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
 
-DATASET_REPO = "WillHeld/model-feedback"  # change if forking
-DATA_DIR = "./feedback_data"
-DATA_FILE = "feedback.jsonl"
-os.makedirs(DATA_DIR, exist_ok=True)
-
-# ─────────────────────────── helper functions ─────────────────────────────
-
-def save_feedback_locally(conversation: List[Dict[str, str]],
-                          satisfaction: str,
-                          feedback_text: str) -> None:
-    record = {
-        "id": str(uuid.uuid4()),
-        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
         "conversation": conversation,
         "satisfaction": satisfaction,
-        "feedback": feedback_text,
     }
-    with open(os.path.join(DATA_DIR, DATA_FILE), "a", encoding="utf-8") as fp:
-        fp.write(json.dumps(record, ensure_ascii=False) + "\n")
-
-
-def push_feedback_to_hub(hf_token: str | None = None) -> bool:
-    hf_token = hf_token or os.getenv("HF_TOKEN")
-    if not hf_token:
-        print("❌ No HF token — skipping Hub push.")
-        return False
-    login(token=hf_token)
-
-    local_path = os.path.join(DATA_DIR, DATA_FILE)
-    if not os.path.exists(local_path):
-        print("❌ No local feedback to push.")
-        return False
-
-    with open(local_path, encoding="utf-8") as fp:
-        local_ds = Dataset.from_list([json.loads(l) for l in fp])
-
     try:
-        remote_ds = load_dataset(
             DATASET_REPO,
-            split="train",
-            token=hf_token,
-            download_mode=DownloadMode.FORCE_REDOWNLOAD,
-        )
-        merged = concatenate_datasets([remote_ds, local_ds]).unique("id")
-    except FileNotFoundError:
-        merged = local_ds
-    except Exception:
-        HfApi(token=hf_token).create_repo(
-            repo_id=DATASET_REPO, repo_type="dataset", private=True
         )
-        merged = local_ds
-
-    merged.push_to_hub(
-        DATASET_REPO,
-        private=True,
-        commit_message=f"Add {len(local_ds)} new feedback entries",
-    )
-    print(f"✅ Pushed {len(local_ds)} rows; dataset now has {len(merged)} total.")
-    return True
-
-# ─────────────────────────── chat backend ────────────────────────────────
 
 @spaces.GPU(duration=120)
-def generate_response(message: str,
-                      history: List[Dict[str, str]],
-                      conversation_state: List[Dict[str, str]],
-                      temperature: float,
-                      top_p: float):
-    """Yields assistant text only; conversation_state is updated in-place."""
-
-    # sync state
     history.append({"role": "user", "content": message})
-    conversation_state[:] = history  # keep external state in sync
-
-    prompt = tokenizer.apply_chat_template(history, tokenize=False,
-                                           add_generation_prompt=True)
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
-
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
-                                    skip_special_tokens=True)
-
-    gen_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=1024,
-        temperature=float(temperature),
-        top_p=float(top_p),
-        do_sample=True,
-        streamer=streamer,
-    )
-
-    Thread(target=model.generate, kwargs=gen_kwargs).start()
-
-    partial = ""
-    for token in streamer:
-        partial += token
-        yield partial  # only the assistant text gets streamed
-
-    history.append({"role": "assistant", "content": partial})
-    conversation_state[:] = history
-    # (no final yield; generator simply ends)
-
-# ─────────────────────────── feedback handler ────────────────────────────
-
-def submit_feedback(conversation_state: List[Dict[str, str]],
-                    satisfaction: str,
-                    feedback_text: str):
-    save_feedback_locally(conversation_state, satisfaction, feedback_text)
-    if push_feedback_to_hub():
-        return "✅ Thanks! Your feedback is safely stored."
-    return "⚠️ Saved locally; Hub push failed. Check server logs."
-
-# ─────────────────────────── UI layout ───────────────────────────────────
-
-with gr.Blocks(title="Marin-8B Research Preview") as demo:
     conversation_state = gr.State([])
-
     with gr.Row():
         with gr.Column(scale=3):
             chatbot = gr.ChatInterface(
-                fn=generate_response,
-                additional_inputs=[conversation_state,
-                                   gr.Slider(0.1, 2.0, value=0.7, step=0.1,
-                                             label="Temperature"),
-                                   gr.Slider(0.1, 1.0, value=0.9, step=0.05,
-                                             label="Top-P")],
-                type="messages",
             )
         with gr.Column(scale=1):
-            report_btn = gr.Button("Share Feedback", variant="primary")
-
-    with Modal(visible=False) as fb_modal:
-        gr.Markdown("## Research Preview Feedback")
-        gr.Markdown("We appreciate your help improving Marin-8B! ✨")
-        sat_radio = gr.Radio([
-            "Very satisfied", "Satisfied", "Neutral",
-            "Unsatisfied", "Very unsatisfied"],
-            label="Overall experience", value="Neutral")
-        fb_text = gr.Textbox(lines=6, label="Comments / suggestions")
-        send_btn = gr.Button("Submit", variant="primary")
-        status_box = gr.Textbox(label="Status", interactive=False)
-
-    report_btn.click(lambda: Modal.update(visible=True), None, fb_modal)
-    send_btn.click(submit_feedback,
-                   inputs=[conversation_state, sat_radio, fb_text],
-                   outputs=status_box)
 
-if __name__ == "__main__":
-    demo.launch()
 
+import spaces
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import gradio as gr
+from threading import Thread
 import os
+import json
 import uuid
+from datasets import Dataset
+from huggingface_hub import HfApi, login
+import time
 
+# Additional dependencies used by the feedback modal and Hub upload
 from gradio_modal import Modal
+import huggingface_hub
+import datasets
 
+# Model setup
+checkpoint = "WillHeld/soft-raccoon"
 device = "cuda"
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
 
+# Constants for dataset
+DATASET_REPO = "WillHeld/model-feedback"  # Replace with your username
+DATASET_PATH = "./feedback_data"  # Local path to store feedback
+DATASET_FILENAME = "feedback.jsonl"  # Filename for feedback data
+
+# Ensure feedback directory exists
+os.makedirs(DATASET_PATH, exist_ok=True)
+
+# Feedback storage functions
+def save_feedback_locally(conversation, satisfaction, feedback_text):
+    """Save feedback to a local JSONL file"""
+    # Create a unique ID for this feedback entry
+    feedback_id = str(uuid.uuid4())
+
+    # Create a timestamp
+    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+
+    # Prepare the feedback data
+    feedback_data = {
+        "id": feedback_id,
+        "timestamp": timestamp,
         "conversation": conversation,
         "satisfaction": satisfaction,
+        "feedback": feedback_text
     }
+
+    # Save to local file
+    feedback_file = os.path.join(DATASET_PATH, DATASET_FILENAME)
+    with open(feedback_file, "a") as f:
+        f.write(json.dumps(feedback_data) + "\n")
+
+    return feedback_id
+
+def push_feedback_to_hub(hf_token=None):
+    """Push the local feedback data to HuggingFace as a dataset"""
+    # Check if we have a token
+    if hf_token is None:
+        # Try to get token from environment variable
+        hf_token = os.environ.get("HF_TOKEN")
+        if hf_token is None:
+            print("No HuggingFace token provided. Cannot push to Hub.")
+            return False
+
     try:
+        # Login to HuggingFace
+        login(token=hf_token)
+
+        # Check if we have data to push
+        feedback_file = os.path.join(DATASET_PATH, DATASET_FILENAME)
+        if not os.path.exists(feedback_file):
+            print("No feedback data to push.")
+            return False
+
+        # Load data from the JSONL file
+        with open(feedback_file, "r") as f:
+            feedback_data = [json.loads(line) for line in f]
+
+        # Create a dataset from the feedback data
+        dataset = Dataset.from_list(feedback_data)
+
+        # Push to Hub
+        dataset.push_to_hub(
             DATASET_REPO,
+            private=True  # Set to False if you want the dataset to be public
         )
+
+        print(f"Feedback data pushed to {DATASET_REPO} successfully.")
+        return True
+
+    except Exception as e:
+        print(f"Error pushing feedback data to Hub: {e}")
+        return False
 
+# Modified predict function to update conversation state
 @spaces.GPU(duration=120)
+def predict(message, history, temperature, top_p):
+    # Update history with user message
     history.append({"role": "user", "content": message})
+
+    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+    # Create a streamer
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Set up generation parameters
+    generation_kwargs = {
+        "input_ids": inputs,
+        "max_new_tokens": 1024,
+        "temperature": float(temperature),
+        "top_p": float(top_p),
+        "do_sample": True,
+        "streamer": streamer,
+    }
+
+    # Run generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Stream the accumulated text, paired with the shared history so the
+    # wrapper below can mirror it into conversation_state
+    partial_text = ""
+    for new_text in streamer:
+        partial_text += new_text
+        yield partial_text, history
+
+    # After full generation, record the assistant's response in history
+    history.append({"role": "assistant", "content": partial_text})
+
+# Function to handle the research feedback submission
+def submit_research_feedback(conversation_state, satisfaction, feedback_text):
+    """Save user feedback both locally and to HuggingFace Hub"""
+    # Save locally first
+    feedback_id = save_feedback_locally(conversation_state, satisfaction, feedback_text)
+
+    # Get token from environment variable
+    env_token = os.environ.get("HF_TOKEN")
+
+    # Use environment token
+    push_success = push_feedback_to_hub(env_token)
+
+    if push_success:
+        status_msg = "Thank you for your valuable feedback! Your insights have been saved to the dataset."
+    else:
+        status_msg = "Thank you for your feedback! It has been saved locally, but couldn't be pushed to the dataset. Please check server logs."
+
+    return status_msg
+
+# Create the Gradio blocks interface
+with gr.Blocks() as demo:
+    # State to track conversation history
     conversation_state = gr.State([])
+
     with gr.Row():
         with gr.Column(scale=3):
+            # Custom chat function wrapper to update state
+            def chat_with_state(message, history, state, temperature, top_p):
+                for partial_response, updated_state in predict(message, history, temperature, top_p):
+                    # Mirror the in-progress history into our state on each yield
+                    state = history.copy()
+                    yield partial_response, state
+                # Yield once more after predict finishes so the state also
+                # captures the assistant's final reply
+                state = history.copy()
+                yield partial_response, state
+
+            # Create ChatInterface
             chatbot = gr.ChatInterface(
+                chat_with_state,
+                additional_inputs=[
+                    conversation_state,
+                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
+                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
+                ],
+                additional_outputs=[conversation_state],
+                type="messages"
             )
+
         with gr.Column(scale=1):
+            report_button = gr.Button("Share Feedback", variant="primary")
+
+    # Create the modal with feedback form components
+    with Modal(visible=False) as feedback_modal:
+        with gr.Column():
+            gr.Markdown("## Research Preview Feedback")
+            gr.Markdown("Thank you for testing our research model. Your feedback (positive or negative) helps us improve!")
+
+            satisfaction = gr.Radio(
+                ["Very satisfied", "Satisfied", "Neutral", "Unsatisfied", "Very unsatisfied"],
+                label="How would you rate your experience with this research model?",
+                value="Neutral"
+            )
+
+            feedback_text = gr.Textbox(
+                lines=5,
+                label="Share your observations (strengths, weaknesses, suggestions):",
+                placeholder="We welcome both positive feedback and constructive criticism to help improve this research prototype..."
+            )
+
+            submit_button = gr.Button("Submit Research Feedback", variant="primary")
+            response_text = gr.Textbox(label="Status", interactive=False)
+
+    # Connect the "Share Feedback" button to show the modal
+    report_button.click(
+        lambda: Modal(visible=True),
+        None,
+        feedback_modal
+    )
+
+    # Connect the submit button to submit_research_feedback with the current conversation state
+    submit_button.click(
+        submit_research_feedback,
+        inputs=[conversation_state, satisfaction, feedback_text],
+        outputs=response_text
+    )
 
+# Launch the demo
+demo.launch()
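
The prompt construction in `predict` relies on the tokenizer's chat template. A minimal sketch of what `apply_chat_template` produces, assuming any instruct checkpoint that ships a template (Qwen2.5-0.5B-Instruct is used here purely as an example, not the app's checkpoint):

from transformers import AutoTokenizer

# Any chat-templated tokenizer works; this checkpoint is just an example
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

history = [{"role": "user", "content": "What does streaming generation mean?"}]

# tokenize=False returns the raw prompt string; add_generation_prompt=True
# appends the assistant header so the model continues in the assistant role
prompt = tok.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
print(prompt)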
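The generation loop itself is the standard `transformers` streaming recipe: `generate` blocks, so it runs in a worker thread while the caller drains `TextIteratorStreamer`. A self-contained sketch of the same pattern, with `gpt2` as a stand-in checkpoint so it runs on CPU:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() runs in a worker thread; the streamer is an iterator that
# yields decoded text chunks as tokens are produced
thread = Thread(target=lm.generate,
                kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer))
thread.start()

partial = ""
for chunk in streamer:
    partial += chunk  # accumulate, exactly like the yield loop in predict
thread.join()
print(partial)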
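The feedback path is an append-to-JSONL file that is re-read and pushed as a `datasets.Dataset`. A sketch of that round trip, writing to a temporary directory; the `push_to_hub` call is left commented out since it needs a real repo (the name below is a placeholder) and an HF token:

import json, os, tempfile, time, uuid
from datasets import Dataset

feedback_file = os.path.join(tempfile.mkdtemp(), "feedback.jsonl")

# Append one record in the same shape save_feedback_locally writes
record = {
    "id": str(uuid.uuid4()),
    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
    "conversation": [{"role": "user", "content": "hi"},
                     {"role": "assistant", "content": "hello"}],
    "satisfaction": "Neutral",
    "feedback": "test entry",
}
with open(feedback_file, "a") as f:
    f.write(json.dumps(record) + "\n")

# Re-read the JSONL and build a Dataset, as push_feedback_to_hub does
with open(feedback_file) as f:
    ds = Dataset.from_list([json.loads(line) for line in f])
print(ds)
# ds.push_to_hub("your-username/model-feedback", private=True)  # needs HF_TOKEN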