from __future__ import annotations

# -- standard lib
import json
import os
import time
import uuid
from threading import Thread

# -- third-party deps (declared in requirements.txt of the Space)
import gradio as gr
from gradio_modal import Modal
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from datasets import Dataset, load_dataset, concatenate_datasets, DownloadMode
from huggingface_hub import HfApi, login
import spaces
# ──────────────────────────── model & constants ─────────────────────────────
checkpoint = "marin-community/marin-8b-instruct"
device = "cuda"  # the Space runner gives us a GPU

# download 🔥
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# feedback dataset details
DATASET_REPO = "WillHeld/model-feedback"  # <-- change to your namespace if needed
DATA_DIR = "./feedback_data"
DATA_FILE = "feedback.jsonl"

os.makedirs(DATA_DIR, exist_ok=True)
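# NOTE: a Space's local disk is ephemeral unless persistent storage is enabled,
# so anything written under DATA_DIR can vanish on restart; that is why each
# record is also mirrored to the Hub dataset below.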
# ──────────────────────────── helpers ───────────────────────────────────────
def save_feedback_locally(conversation: list[dict[str, str]],
                          satisfaction: str,
                          feedback_text: str) -> str:
    """Append a single feedback record to a JSONL file and return its UUID."""
    record = {
        "id": str(uuid.uuid4()),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        "conversation": conversation,
        "satisfaction": satisfaction,
        "feedback": feedback_text,
    }
    fp = os.path.join(DATA_DIR, DATA_FILE)
    with open(fp, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return record["id"]
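# Each line of feedback.jsonl is one self-contained record, e.g.
# (illustrative values only):
#   {"id": "6f1e...", "timestamp": "2024-01-01 12:00:00",
#    "conversation": [{"role": "user", "content": "hi"}],
#    "satisfaction": "Neutral", "feedback": "..."}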
def push_feedback_to_hub(hf_token: str | None = None) -> bool:  # noqa: C901
    """Merge freshly collected feedback with what's already on the Hub.

    Steps
    -----
    1. Authenticate with `hf_token` (fall back to $HF_TOKEN env).
    2. Load *local* feedback just written in `feedback.jsonl`.
    3. Pull the existing remote split (if any); concatenate and drop duplicate ids.
    4. Push the merged dataset back. Never deletes remote shards – safe.
    """
    hf_token = hf_token or os.getenv("HF_TOKEN")
    if not hf_token:
        print("⚠️ No HF token – skipping Hub push.")
        return False
    login(token=hf_token)

    fp = os.path.join(DATA_DIR, DATA_FILE)
    if not os.path.exists(fp):
        print("⚠️ Local feedback file missing; nothing to push.")
        return False

    # local rows → Dataset
    with open(fp, encoding="utf-8") as f:
        local_ds = Dataset.from_list([json.loads(line) for line in f])

    # try to pull remote
    try:
        remote_ds = load_dataset(
            DATASET_REPO,
            split="train",
            token=hf_token,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
        )
        # `Dataset.unique("id")` returns a plain list, not a Dataset, so
        # de-duplicate on "id" via a pandas round-trip instead
        combined = concatenate_datasets([remote_ds, local_ds])
        merged = Dataset.from_pandas(
            combined.to_pandas().drop_duplicates(subset="id"),
            preserve_index=False,
        )
    except FileNotFoundError:
        # repo exists but empty
        merged = local_ds
    except Exception:
        # repo may not exist yet → create & start fresh
        HfApi(token=hf_token).create_repo(
            repo_id=DATASET_REPO, repo_type="dataset", private=True
        )
        merged = local_ds

    merged.push_to_hub(
        DATASET_REPO,
        private=True,
        commit_message=f"Add {len(local_ds)} new feedback entries",
    )
    print(f"✅ Pushed {len(local_ds)} rows; dataset now has {len(merged)} total.")
    # (optional) clear local file once synced
    # os.remove(fp)
    return True
# ──────────────────────────── chat backend ──────────────────────────────────
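# `spaces` is imported above but was never used: on ZeroGPU hardware the
# GPU-bound function must carry the decorator below (it is a no-op on
# dedicated-GPU Spaces, so applying it unconditionally is safe).
@spaces.GPU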
def generate_response(message: str,
                      history: list[dict[str, str]],
                      temperature: float,
                      top_p: float):
    """Streaming generator used by the Gradio ChatInterface."""
    # 1) add user message to history
    history.append({"role": "user", "content": message})

    # 2) build model input via chat template
    prompt = tokenizer.apply_chat_template(history, tokenize=False,
                                           add_generation_prompt=True)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=1024,
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        streamer=streamer,
    )

    # run on a worker thread so we can yield tokens live
    Thread(target=model.generate, kwargs=gen_kwargs).start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial, history  # 1st out = msg, 2nd out = state
    # once finished, commit assistant reply to history
    history.append({"role": "assistant", "content": partial})
    yield partial, history
# ──────────────────────────── feedback handler ──────────────────────────────
def submit_feedback(conversation_state: list[dict[str, str]],
                    satisfaction: str,
                    feedback_text: str):
    """Callback for the *Submit Research Feedback* button."""
    save_feedback_locally(conversation_state, satisfaction, feedback_text)
    pushed = push_feedback_to_hub()
    if pushed:
        return "✅ Thanks! Your feedback is safely stored."
    return "⚠️ Saved locally; Hub push failed. Check server logs."
# ──────────────────────────── UI layout ─────────────────────────────────────
with gr.Blocks(title="Marin-8B Research Preview") as demo:
    # state object to surface chat history to the feedback form
    conversation_state = gr.State([])

    with gr.Row():
        # ─── Chat column ───
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                fn=generate_response,
                # only the sliders are extra *inputs* – passing the state here
                # too would shift the argument binding and break the
                # (message, history, temperature, top_p) signature; the state
                # is kept in sync via `additional_outputs` below
                additional_inputs=[gr.Slider(0.1, 2.0, value=0.7, step=0.1,
                                             label="Temperature"),
                                   gr.Slider(0.1, 1.0, value=0.9, step=0.05,
                                             label="Top-P")],
                additional_outputs=[conversation_state],
                type="messages",
            )
        # ─── Sidebar column ───
        with gr.Column(scale=1):
            report_btn = gr.Button("Share Feedback", variant="primary")

    # feedback modal (hidden by default)
    with Modal(visible=False) as fb_modal:
        gr.Markdown("## Research Preview Feedback")
        gr.Markdown("We appreciate your help improving Marin-8B! ✨")
        sat_radio = gr.Radio(
            ["Very satisfied", "Satisfied", "Neutral",
             "Unsatisfied", "Very unsatisfied"],
            label="Overall experience",
            value="Neutral",
        )
        fb_text = gr.Textbox(lines=6, label="Comments / suggestions")
        send_btn = gr.Button("Submit", variant="primary")
        status_box = gr.Textbox(label="Status", interactive=False)

    # interactions
    # open the modal without custom JS – use a Modal update
    report_btn.click(lambda: Modal(visible=True), None, fb_modal)
    send_btn.click(
        submit_feedback,
        inputs=[conversation_state, sat_radio, fb_text],
        outputs=status_box,
    )
# ──────────────────────────── run! ──────────────────────────────────────────
if __name__ == "__main__":
    demo.launch()
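    # tip: demo.launch(share=True) would additionally expose a temporary
    # public link, which can be handy when debugging outside a Space
    # (assumption: the default launch settings suffice inside the Space).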