WillHeld committed (verified)
Commit b3a18de · Parent: 06b1cf8

Update app.py

Files changed (1):
  app.py (+45 -80)
app.py CHANGED
@@ -1,13 +1,12 @@
 from __future__ import annotations
 
-# -- standard lib
 import json
 import os
 import time
 import uuid
 from threading import Thread
+from typing import List, Dict
 
-# -- third-party deps (declared in requirements.txt of the Space)
 import gradio as gr
 from gradio_modal import Modal
 from transformers import (
@@ -15,30 +14,32 @@ from transformers import (
     AutoTokenizer,
     TextIteratorStreamer,
 )
-from datasets import Dataset, load_dataset, concatenate_datasets, DownloadMode
+from datasets import (
+    Dataset,
+    load_dataset,
+    concatenate_datasets,
+    DownloadMode,
+)
 from huggingface_hub import HfApi, login
 import spaces
 
-# ──────────────────────────── model & constants ─────────────────────────────
+# ─────────────────────────── model & constants ────────────────────────────
 checkpoint = "marin-community/marin-8b-instruct"
-device = "cuda"  # the Space runner gives us a GPU
+device = "cuda"
 
-# download 🔥
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
 
-# feedback dataset details
-DATASET_REPO = "WillHeld/model-feedback"  # <-- change to your namespace if needed
+DATASET_REPO = "WillHeld/model-feedback"  # change if forking
 DATA_DIR = "./feedback_data"
 DATA_FILE = "feedback.jsonl"
 os.makedirs(DATA_DIR, exist_ok=True)
 
-# ──────────────────────────── helpers ───────────────────────────────────────
+# ─────────────────────────── helper functions ─────────────────────────────
 
-def save_feedback_locally(conversation: list[dict[str, str]],
+def save_feedback_locally(conversation: List[Dict[str, str]],
                           satisfaction: str,
-                          feedback_text: str) -> str:
-    """Append a single feedback record to a JSONL file and return its UUID."""
+                          feedback_text: str) -> None:
     record = {
         "id": str(uuid.uuid4()),
         "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
@@ -46,39 +47,25 @@ def save_feedback_locally(conversation: list[dict[str, str]],
         "satisfaction": satisfaction,
         "feedback": feedback_text,
     }
-    fp = os.path.join(DATA_DIR, DATA_FILE)
-    with open(fp, "a", encoding="utf-8") as f:
-        f.write(json.dumps(record, ensure_ascii=False) + "\n")
-    return record["id"]
-
+    with open(os.path.join(DATA_DIR, DATA_FILE), "a", encoding="utf-8") as fp:
+        fp.write(json.dumps(record, ensure_ascii=False) + "\n")
 
-def push_feedback_to_hub(hf_token: str | None = None) -> bool:  # noqa: C901
-    """Merge freshly collected feedback with what’s already on the Hub.
-
-    Steps
-    -----
-    1. Authenticate with `hf_token` (fall back to $HF_TOKEN env).
-    2. Load *local* feedback just written in `feedback.jsonl`.
-    3. Pull existing remote split (if any); concat & `unique("id")`.
-    4. Push the merged dataset back. Never deletes remote shards ⇒ safe.
-    """
 
+def push_feedback_to_hub(hf_token: str | None = None) -> bool:
     hf_token = hf_token or os.getenv("HF_TOKEN")
     if not hf_token:
        print("❌ No HF token — skipping Hub push.")
        return False
     login(token=hf_token)
 
-    fp = os.path.join(DATA_DIR, DATA_FILE)
-    if not os.path.exists(fp):
-        print("❌ Local feedback file missing; nothing to push.")
+    local_path = os.path.join(DATA_DIR, DATA_FILE)
+    if not os.path.exists(local_path):
+        print("❌ No local feedback to push.")
         return False
 
-    # local rows → Dataset
-    with open(fp, encoding="utf-8") as f:
-        local_ds = Dataset.from_list([json.loads(l) for l in f])
+    with open(local_path, encoding="utf-8") as fp:
+        local_ds = Dataset.from_list([json.loads(l) for l in fp])
 
-    # try to pull remote
     try:
         remote_ds = load_dataset(
             DATASET_REPO,
@@ -88,10 +75,8 @@ def push_feedback_to_hub(hf_token: str | None = None) -> bool:  # noqa: C901
         )
         merged = concatenate_datasets([remote_ds, local_ds]).unique("id")
     except FileNotFoundError:
-        # repo exists but empty
         merged = local_ds
     except Exception:
-        # repo may not exist yet – create & start fresh
         HfApi(token=hf_token).create_repo(
             repo_id=DATASET_REPO, repo_type="dataset", private=True
         )
@@ -102,25 +87,23 @@ def push_feedback_to_hub(hf_token: str | None = None) -> bool:  # noqa: C901
         private=True,
         commit_message=f"Add {len(local_ds)} new feedback entries",
     )
-    print(
-        f"✅ Pushed {len(local_ds)} rows; dataset now has {len(merged)} total.")
-    # (optional) clear local file once synced
-    # os.remove(fp)
+    print(f"✅ Pushed {len(local_ds)} rows; dataset now has {len(merged)} total.")
     return True
 
-# ──────────────────────────── chat backend ─────────────────────────────────
+# ─────────────────────────── chat backend ────────────────────────────────
 
 @spaces.GPU(duration=120)
 def generate_response(message: str,
-                      history: list[dict[str, str]],
+                      history: List[Dict[str, str]],
+                      conversation_state: List[Dict[str, str]],
                       temperature: float,
                       top_p: float):
-    """Streaming generator used by the Gradio ChatInterface."""
+    """Yields assistant text only; conversation_state is updated in‑place."""
 
-    # 1) add user message to history
+    # sync state
     history.append({"role": "user", "content": message})
+    conversation_state[:] = history  # keep external state in sync
 
-    # 2) build model input via chat template
     prompt = tokenizer.apply_chat_template(history, tokenize=False,
                                            add_generation_prompt=True)
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
@@ -137,79 +120,61 @@ def generate_response(message: str,
         streamer=streamer,
     )
 
-    # run on a worker thread so we can yield tokens live
     Thread(target=model.generate, kwargs=gen_kwargs).start()
 
     partial = ""
     for token in streamer:
         partial += token
-        yield partial, history  # 1st out = msg, 2nd out = state
+        yield partial  # only the assistant text gets streamed
 
-    # once finished, commit assistant reply to history
     history.append({"role": "assistant", "content": partial})
-    yield partial, history
+    conversation_state[:] = history
+    # (no final yield; generator simply ends)
 
-# ──────────────────────────── feedback handler ─────────────────────────────
+# ─────────────────────────── feedback handler ────────────────────────────
 
-def submit_feedback(conversation_state: list[dict[str, str]],
+def submit_feedback(conversation_state: List[Dict[str, str]],
                     satisfaction: str,
                     feedback_text: str):
-    """Callback for the *Submit Research Feedback* button."""
     save_feedback_locally(conversation_state, satisfaction, feedback_text)
-    pushed = push_feedback_to_hub()
-    if pushed:
+    if push_feedback_to_hub():
        return "✅ Thanks! Your feedback is safely stored."
     return "⚠️ Saved locally; Hub push failed. Check server logs."
 
-# ──────────────────────────── UI layout ────────────────────────────────────
+# ─────────────────────────── UI layout ───────────────────────────────────
 
-with gr.Blocks(title="Marin-8B Research Preview") as demo:
-    # state object to surface chat history to the feedback form
+with gr.Blocks(title="Marin‑8B Research Preview") as demo:
     conversation_state = gr.State([])
 
     with gr.Row():
-        # ——— Chat column ———
         with gr.Column(scale=3):
             chatbot = gr.ChatInterface(
                 fn=generate_response,
-                additional_inputs=[conversation_state,  # keeps state in sync
+                additional_inputs=[conversation_state,
                                    gr.Slider(0.1, 2.0, value=0.7, step=0.1,
                                              label="Temperature"),
                                    gr.Slider(0.1, 1.0, value=0.9, step=0.05,
-                                             label="Top-P")],
-                additional_outputs=[conversation_state],
+                                             label="Top‑P")],
                 type="messages",
             )
-
-        # ——— Sidebar column ———
         with gr.Column(scale=1):
             report_btn = gr.Button("Share Feedback", variant="primary")
 
-    # feedback modal (hidden by default)
     with Modal(visible=False) as fb_modal:
         gr.Markdown("## Research Preview Feedback")
-        gr.Markdown("We appreciate your help improving Marin-8B! ✨")
-
+        gr.Markdown("We appreciate your help improving Marin‑8B! ✨")
         sat_radio = gr.Radio([
             "Very satisfied", "Satisfied", "Neutral",
             "Unsatisfied", "Very unsatisfied"],
-            label="Overall experience",
-            value="Neutral",
-        )
+            label="Overall experience", value="Neutral")
        fb_text = gr.Textbox(lines=6, label="Comments / suggestions")
        send_btn = gr.Button("Submit", variant="primary")
        status_box = gr.Textbox(label="Status", interactive=False)
 
-    # interactions
-    # open the modal without custom JS – use Modal update
-    report_btn.click(lambda: Modal(visible=True), None, fb_modal)
-
-    send_btn.click(
-        submit_feedback,
-        inputs=[conversation_state, sat_radio, fb_text],
-        outputs=status_box,
-    )
+    report_btn.click(lambda: Modal.update(visible=True), None, fb_modal)
+    send_btn.click(submit_feedback,
                    inputs=[conversation_state, sat_radio, fb_text],
                    outputs=status_box)
 
-# ──────────────────────────── run! ─────────────────────────────────────────
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
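
The core behavioral change in this commit: generate_response no longer yields (text, history) tuples wired through additional_outputs; it streams only the assistant text and mirrors the transcript into the shared conversation_state list via slice assignment. A minimal sketch of why the in-place list[:] update keeps the gr.State value in sync (plain Python, no Gradio; stream_reply and the token list are illustrative stand-ins, not code from this Space):

def stream_reply(history: list[dict], state: list[dict], message: str):
    """Toy stand-in for generate_response: streams text, mutates state in place."""
    history.append({"role": "user", "content": message})
    state[:] = history  # slice assignment rewrites the SAME list object,
                        # so anything holding a reference to it sees the new turns
    partial = ""
    for token in ("Hel", "lo", "!"):  # stand-in for the TextIteratorStreamer
        partial += token
        yield partial  # stream only the assistant text, as in the new code
    history.append({"role": "assistant", "content": partial})
    state[:] = history

shared_state: list[dict] = []  # plays the role of the gr.State([]) value
for chunk in stream_reply([], shared_state, "Hi"):
    print(chunk)               # Hel / Hello / Hello!
print(len(shared_state))       # 2: user and assistant turns, no extra output wiring

Note that rebinding (state = history) would only change the local name inside the generator; the slice assignment is what makes the mutation visible to the outside holder of the list.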