#!/usr/bin/env python3
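"""What Comes Next: a Gradio Space that slowly reveals an LLM continuation.

A background thread extends a randomly chosen prompt by one token every
SECS_PER_TOKEN seconds; visitors read the partial output and log their own
prediction of what comes next to data.json.
"""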
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
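
# Configuration: model, state/data files, and generation settings.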
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
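
# Small JSON helpers: _rj reads a file with a fallback default, _aw writes
# atomically via a temp file + os.replace so readers never see partial JSON.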
def _rj(p, d):
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return d

def _aw(p, o):
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as f:
        f.write(json.dumps(o, ensure_ascii=False, indent=2))
    os.replace(t, p)
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise Exception("No prompts found in full_prompts.json")
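
# HF_READ_TOKEN is only required if the model repo is gated.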
tok = os.environ.get("HF_READ_TOKEN")
log.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=False,
    token=tok
)
model.to("cpu"); model.eval()
log.info("Model is ready.")
lock = threading.Lock()
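
# Shared state lives in current_state.json:
#   i = prompt index, p = prompt text, g = generated text so far,
#   c = token count, t = start timestamp, finished = generation complete.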
def _init():
    state = _rj(STATE_PATH, {})
    if not state or state.get("finished"):
        idx = random.randrange(len(prompts))
        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, state)
    return state

def _es(start_time):
    elapsed = int(time.time() - start_time)
    h, rem = divmod(elapsed, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"
def _loop():
    while True:
        with lock:
            st = _init()
        if st["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        context = st["p"] + st["g"]
        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad():
            out = model.generate(
                ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMP,
                top_p=TOP_P
            )
        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            st["g"] += next_token
            st["c"] += 1
            if st["c"] >= TOKENS_PER_PROMPT:
                st["finished"] = True
            _aw(STATE_PATH, st)
        time.sleep(SECS_PER_TOKEN)
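
# Run the generator in a daemon thread so it does not block the Gradio server.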
threading.Thread(target=_loop, daemon=True).start()
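
# Read-only snapshot of the current prompt, partial output, and elapsed time.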
def _fetch():
    state = _rj(STATE_PATH, {})
    if not state:
        return "...", "", "0h 0m 0s"
    return state["p"], state["g"], _es(state["t"])
def _submit_prediction(detailed, summary):
    det = detailed.strip()
    if not det:
        return gr.update(value="Please enter at least a detailed prediction."), gr.update(value=""), gr.update(value="")
    prompt_text, oracle_resp, elapsed = _fetch()
    record = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt_text,
        "time": elapsed,
        "resp": oracle_resp,
        "prediction": det,
        "summary": summary.strip()
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")
with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown(
        "# What Comes Next\n"
        "Enter what you think will come next in the text.\n"
        "Provide a detailed continuation and optionally a brief summary for context."
    )
    with gr.Row():
        prompt_md = gr.Markdown()
        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
    detailed = gr.Textbox(
        label="Your Detailed Prediction",
        placeholder="Enter the full text continuation you expect...",
        lines=3
    )
    summary = gr.Textbox(
        label="Prediction Summary (Optional)",
        placeholder="Optionally, summarize your prediction in a few words...",
        lines=2
    )
    status = gr.Textbox(interactive=False, label="Status")
    submit_btn = gr.Button("Submit Prediction")
    refresh_btn = gr.Button("Refresh Oracle")
    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
    submit_btn.click(
        _submit_prediction,
        inputs=[detailed, summary],
        outputs=[status, detailed, summary]
    )
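
# Bind on all interfaces; 7860 is the default port a Gradio Space serves on.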
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)