theostos committed
Commit 5e2217f · 1 Parent(s): 0ceabea

Initial commit

Files changed (3)
  1. app.py +167 -0
  2. requirements.txt +5 -0
  3. test.json +3 -0
app.py ADDED
@@ -0,0 +1,167 @@
+ import os
+ import torch
+ import gradio as gr
+ import spaces
+ import json
+ import random
+ from threading import Thread
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+ # Model and test-set locations, both overridable via environment variables.
+ MODEL_ID = os.getenv("MODEL_ID", "theostos/babel-ssreflect-fp8")
+ TEST_JSON_PATH = os.getenv("TEST_JSON_PATH", "test.json")
+
+ INSTRUCTION_TEMPLATE = "You are given a proof term:\n\n{term}\n\nYour task is to derive a sequence of SSReflect tactics that corresponds to this term.\n\nWhen you work through the problem, write down your reasoning in detail inside <think> ... </think> tags. This reasoning should reflect your natural thought process as you explore the structure of the term and figure out what tactics to apply. You should consider different possible approaches, reflect on why some might or might not work, and gradually converge on a tactic choice.\n\nAfter each reasoning block, provide the next (group of) tactic(s) enclosed in:\n\n\\box{{\n <tactic>\n}}\n\nSome dependencies that could be helpful:\n\n{dependencies}"
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)
+ if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ # Lazy-loaded singleton so the checkpoint is only pulled on the first request.
+ _model = None
+ def load_model():
+     global _model
+     if _model is None:
+         _model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             token=HF_TOKEN,
+             device_map="auto",
+             dtype="auto",
+             trust_remote_code=True
+         )
+     return _model
+
+ def build_messages(term: str, deps: str):
+     instr = INSTRUCTION_TEMPLATE.format(term=term, dependencies=deps)
+     return [{"role": "user", "content": instr}]
+
+ def load_test_examples(path=TEST_JSON_PATH):
+     """
+     Expects a JSON list of dicts with keys:
+       - 'term'
+       - 'notations'
+       - 'constants'
+       - 'steps'
+     """
+     try:
+         with open(path, "r", encoding="utf-8") as f:
+             data = json.load(f)
+         if not isinstance(data, list):
+             raise ValueError("Test set JSON must be a list of objects.")
+
+         for entry in data:
+             # Join notations and constants into one newline-separated block.
+             entry['dependencies'] = "\n".join(entry['notations'] + entry['constants'])
+             entry['initial_proof'] = "\n".join(entry['steps'])
+         print(f"[info] Loaded {len(data)} test examples from {path}")
+         return data
+     except Exception as e:
+         print(f"[warn] Could not load test set {path}: {e}")
+         return []
+
+ TEST_EXAMPLES = load_test_examples()
+
+ def _duration(term, deps, temperature, top_p, max_new_tokens):
+     # Scale the requested GPU lease with the generation budget (60-300 s).
+     return int(min(300, max(60, (int(max_new_tokens) / 2.5) + 30)))
+
+ @spaces.GPU(duration=_duration)
+ def generate(term, deps, temperature, top_p, max_new_tokens):
+     model = load_model()
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     messages = build_messages(term, deps)
+
+     prompt_text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+     inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     gen_kwargs = dict(
+         **inputs,  # input_ids and attention_mask from the tokenizer
+         max_new_tokens=int(max_new_tokens),
+         temperature=float(temperature),
+         top_p=float(top_p),
+         do_sample=True,
+         streamer=streamer,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+     )
+
+     # Generate on a background thread; the streamer yields tokens as they arrive.
+     thread = Thread(target=model.generate, kwargs=gen_kwargs)
+     thread.start()
+
+     out = ""
+     for token in streamer:  # stream tokens to the UI
+         out += token
+         yield f"```rocq\n{out}\n```"
+
+
+ def _sample_test_example():
+     if not TEST_EXAMPLES:
+         return "", "", "No test examples loaded. Set TEST_JSON_PATH or add test.json at repo root."
+     ex = random.choice(TEST_EXAMPLES)
+     return ex['term'], ex['dependencies'], ex['initial_proof']
+
+ # Hot-reload the test set without restarting the Space.
+ def _reload_test_set():
+     global TEST_EXAMPLES
+     TEST_EXAMPLES = load_test_examples()
+     return gr.update(value=f"Reloaded {len(TEST_EXAMPLES)} test examples from {TEST_JSON_PATH}.")
+
+ with gr.Blocks(title="Proof translator (ZeroGPU, FP8)") as demo:
+     gr.Markdown(
+         "# Vanilla Rocq to SSReflect proof translator\n"
+         "Write a proof term, "
+         "then list the dependencies appearing in the source proof.\n\n"
+         "You can also use **🎲 Draw test example** to pull a sample from the test set."
+     )
+
+     with gr.Row():
+         sample_btn = gr.Button("🎲 Draw test example", variant="secondary")
+         reload_test_btn = gr.Button("Reload test set", variant="secondary")
+
+     with gr.Row():
+         term_box = gr.Code(
+             label="Pretty-printed proof term",
+             language=None,
+             interactive=True,
+             lines=18,
+         )
+         dep_box = gr.Code(
+             label="Dependencies contained in the proof term",
+             language=None,
+             interactive=True,
+             lines=18,
+         )
+
+     with gr.Row():
+         temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
+         top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
+         max_new = gr.Slider(256, 8192, value=4096, step=128, label="max_new_tokens")
+
+     # Output panels: model output vs. baseline/ground truth
+     with gr.Row():
+         out = gr.Markdown(label="Generated proof")
+         baseline = gr.Code(label="Source proof", language=None)
+     btn = gr.Button("Translate", variant="primary")
+
+     test_notice = gr.Markdown("")
+     sample_btn.click(_sample_test_example, inputs=None, outputs=[term_box, dep_box, baseline])
+     reload_test_btn.click(_reload_test_set, inputs=None, outputs=test_notice)
+     btn.click(
+         generate,
+         inputs=[term_box, dep_box, temperature, top_p, max_new],
+         outputs=out,
+         concurrency_limit=1,
+     )
+
+ demo.queue(max_size=20, default_concurrency_limit=1)
+
+ if __name__ == "__main__":
+     demo.launch()
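
Since `test.json` is tracked with Git LFS (see the pointer below), its contents don't appear in this diff. The sketch below shows the shape of an entry that `load_test_examples` would accept; the key names come from the loader above, while every field value is a made-up placeholder.

```python
import json

# Hypothetical entry: key names match load_test_examples; values are placeholders.
example = [
    {
        "term": "fun n : nat => addn0 n",                    # pretty-printed proof term
        "notations": ['Notation "m + n" := (addn m n).'],    # joined into 'dependencies'
        "constants": ["addn0 : forall n : nat, n + 0 = n"],  # joined into 'dependencies'
        "steps": ["by rewrite addn0."],                      # joined into 'initial_proof'
    }
]

with open("test.json", "w", encoding="utf-8") as f:
    json.dump(example, f, indent=2)
```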
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==2.8.0
+ transformers>=4.57.1
+ accelerate>=1.10
+ gradio>=4.44
+ spaces>=0.42
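
With these dependencies installed, the prompt construction can be smoke-tested on CPU without loading the model weights; a minimal sketch, assuming the tokenizer repo is public or `HF_TOKEN` grants access:

```python
from transformers import AutoTokenizer

MODEL_ID = "theostos/babel-ssreflect-fp8"  # same default as in app.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Mirrors build_messages() from app.py with a placeholder term.
messages = [{"role": "user", "content": "You are given a proof term: ..."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt[:500])  # inspect how the chat template wraps the instruction
```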
test.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10de7428eb265987667a909e16cd4cc58f1e983dbb3099d1cb9b8d5be1c829c6
+ size 1922622
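
The file itself is only an LFS pointer: `oid` is the SHA-256 of the real payload and `size` its byte length. A quick sketch to check that a fetched copy of `test.json` matches this pointer:

```python
import hashlib

# Values copied from the LFS pointer above.
EXPECTED_OID = "10de7428eb265987667a909e16cd4cc58f1e983dbb3099d1cb9b8d5be1c829c6"
EXPECTED_SIZE = 1922622

with open("test.json", "rb") as f:
    blob = f.read()

assert len(blob) == EXPECTED_SIZE, f"size mismatch: got {len(blob)} bytes"
assert hashlib.sha256(blob).hexdigest() == EXPECTED_OID, "sha256 mismatch"
print("test.json matches its LFS pointer")
```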