from __future__ import annotations

# -- standard lib
import json
import os
import time
import uuid
from threading import Thread

# -- third-party deps (declared in requirements.txt of the Space)
import gradio as gr
from gradio_modal import Modal
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from datasets import Dataset, load_dataset, concatenate_datasets, DownloadMode
from huggingface_hub import HfApi, login
import spaces

# ──────────────────────────── model & constants ─────────────────────────────
checkpoint = "marin-community/marin-8b-instruct"
device = "cuda"  # the Space runner gives us a GPU

# download 🔥
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
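# NOTE (assumption, not in the original code): the default load is full fp32, which
# is roughly 32 GB for an 8B model; passing torch_dtype=torch.bfloat16 to
# from_pretrained (plus `import torch`) would roughly halve that if VRAM is tight.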

# feedback dataset details
DATASET_REPO = "WillHeld/model-feedback"  # <-- change to your namespace if needed
DATA_DIR = "./feedback_data"
DATA_FILE = "feedback.jsonl"
os.makedirs(DATA_DIR, exist_ok=True)

# ──────────────────────────── helpers ───────────────────────────────────────

def save_feedback_locally(conversation: list[dict[str, str]],
                          satisfaction: str,
                          feedback_text: str) -> str:
    """Append a single feedback record to a JSONL file and return its UUID."""
    record = {
        "id": str(uuid.uuid4()),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        "conversation": conversation,
        "satisfaction": satisfaction,
        "feedback": feedback_text,
    }
    fp = os.path.join(DATA_DIR, DATA_FILE)
    with open(fp, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return record["id"]


def push_feedback_to_hub(hf_token: str | None = None) -> bool:  # noqa: C901
    """Merge freshly collected feedback with what’s already on the Hub.

    Steps
    -----
    1.  Authenticate with `hf_token` (fall back to $HF_TOKEN env).
    2.  Load *local* feedback just written in `feedback.jsonl`.
    3.  Pull existing remote split (if any); concatenate and drop duplicate ids.
    4.  Push the merged dataset back. Never deletes remote shards ⇒ safe.
    """

    hf_token = hf_token or os.getenv("HF_TOKEN")
    if not hf_token:
        print("❌  No HF token β€” skipping Hub push.")
        return False
    login(token=hf_token)

    fp = os.path.join(DATA_DIR, DATA_FILE)
    if not os.path.exists(fp):
        print("❌  Local feedback file missing; nothing to push.")
        return False

    # local rows → Dataset
    with open(fp, encoding="utf-8") as f:
        local_ds = Dataset.from_list([json.loads(l) for l in f])

    # try to pull remote
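    # FORCE_REDOWNLOAD skips the local datasets cache so we merge against the
    # current remote copy rather than a stale one.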
    try:
        remote_ds = load_dataset(
            DATASET_REPO,
            split="train",
            token=hf_token,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
        )
        # Dataset.unique() returns a plain list, so de-duplicate by "id" via pandas
        combined = concatenate_datasets([remote_ds, local_ds]).to_pandas()
        merged = Dataset.from_pandas(
            combined.drop_duplicates(subset="id"), preserve_index=False
        )
    except FileNotFoundError:
        # repo exists but empty
        merged = local_ds
    except Exception:
        # repo may not exist yet – create & start fresh
        HfApi(token=hf_token).create_repo(
            repo_id=DATASET_REPO, repo_type="dataset", private=True
        )
        merged = local_ds

    merged.push_to_hub(
        DATASET_REPO,
        private=True,
        commit_message=f"Add {len(local_ds)} new feedback entries",
    )
    print(
        f"βœ…  Pushed {len(local_ds)} rows; dataset now has {len(merged)} total.")
    # (optional) clear local file once synced
    # os.remove(fp)
    return True

# ──────────────────────────── chat backend ─────────────────────────────────

@spaces.GPU(duration=120)
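# On ZeroGPU Spaces this decorator requests a GPU for up to `duration` seconds per
# call; on dedicated GPU hardware the `spaces` package treats it as a no-op.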
def generate_response(message: str,
                      history: list[dict[str, str]],
                      state: list[dict[str, str]],  # extra input from conversation_state
                      temperature: float,
                      top_p: float):
    """Streaming generator used by the Gradio ChatInterface."""

    # 1) add user message to history
    history.append({"role": "user", "content": message})

    # 2) build model input via chat template
    prompt = tokenizer.apply_chat_template(history, tokenize=False,
                                           add_generation_prompt=True)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)

    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=1024,
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        streamer=streamer,
    )

    # run on a worker thread so we can yield tokens live
    Thread(target=model.generate, kwargs=gen_kwargs).start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial, history  # 1st out = msg, 2nd out = state

    # once finished, commit assistant reply to history
    history.append({"role": "assistant", "content": partial})
    yield partial, history

# ──────────────────────────── feedback handler ─────────────────────────────

def submit_feedback(conversation_state: list[dict[str, str]],
                    satisfaction: str,
                    feedback_text: str):
    """Callback for the *Submit Research Feedback* button."""
    save_feedback_locally(conversation_state, satisfaction, feedback_text)
    pushed = push_feedback_to_hub()
    if pushed:
        return "βœ…  Thanks! Your feedback is safely stored."
    return "⚠️  Saved locally; Hub push failed. Check server logs."

# ──────────────────────────── UI layout ────────────────────────────────────

with gr.Blocks(title="Marin-8B Research Preview") as demo:
    # state object to surface chat history to the feedback form
    conversation_state = gr.State([])

    with gr.Row():
        # ─── Chat column ───
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                fn=generate_response,
                additional_inputs=[conversation_state,  # keeps state in sync
                                   gr.Slider(0.1, 2.0, value=0.7, step=0.1,
                                             label="Temperature"),
                                   gr.Slider(0.1, 1.0, value=0.9, step=0.05,
                                             label="Top-P")],
                additional_outputs=[conversation_state],
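                # the second value yielded by generate_response updates this state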
                type="messages",
            )

        # ─── Sidebar column ───
        with gr.Column(scale=1):
            report_btn = gr.Button("Share Feedback", variant="primary")

    # feedback modal (hidden by default)
    with Modal(visible=False) as fb_modal:
        gr.Markdown("## Research Preview Feedback")
        gr.Markdown("We appreciate your help improving Marin-8B! ✨")

        sat_radio = gr.Radio([
            "Very satisfied", "Satisfied", "Neutral",
            "Unsatisfied", "Very unsatisfied"],
            label="Overall experience",
            value="Neutral",
        )
        fb_text = gr.Textbox(lines=6, label="Comments / suggestions")
        send_btn = gr.Button("Submit", variant="primary")
        status_box = gr.Textbox(label="Status", interactive=False)

    # interactions
    # open the modal without custom JS – use Modal update
    report_btn.click(lambda: Modal(visible=True), None, fb_modal)

    send_btn.click(
        submit_feedback,
        inputs=[conversation_state, sat_radio, fb_text],
        outputs=status_box,
    )

# ──────────────────────────── run! ─────────────────────────────────────────
if __name__ == "__main__":
    demo.launch()