kaustubhg73 committed on
Commit 5b695bd · 1 Parent(s): 267d60a
Files changed (12)
  1. README.md +165 -111
  2. client.py +13 -3
  3. env/adapt_env.py +301 -67
  4. env/generator.py +990 -131
  5. env/test_cases.py +1 -3
  6. models.py +11 -3
  7. openenv.yaml +3 -31
  8. scripts/test_env.py +42 -15
  9. server/app.py +71 -15
  10. training/plot_results.py +139 -0
  11. training/train_grpo.py +436 -60
  12. verifier/metrics.py +55 -16
README.md CHANGED
@@ -8,196 +8,250 @@ tags:
8
  - openenv
9
  - reinforcement-learning
10
  - code-generation
 
11
  ---
12
 
13
- # ADAPT DSA Tutor OpenEnv
14
 
15
- ADAPT, the Adversarial DSA Tutor, is an OpenEnv-compliant RLVR environment for training code-generation agents on small DSA tasks. The agent receives a problem prompt, examples, and visible tests, then submits Python code. The environment runs the code against visible and hidden tests and returns reward, pass-rate metrics, execution status, and feedback.
16
 
17
- This repo includes the environment, verifier helpers, a baseline inference runner, and a GRPO training entrypoint so the full submission flow can be exercised from one codebase.
18
 
19
- ## Why This Environment
20
 
21
- The hackathon asks for OpenEnv environments that can improve LLM behavior through verifiable interaction. ADAPT targets a simple but useful skill loop:
 
 
 
 
 
 
 
22
 
23
  ```text
24
- agent writes code -> environment executes it -> hidden tests and reward signals score it -> trainer improves the agent
 
 
 
 
 
25
  ```
26
 
27
- The differentiator is curriculum-ready DSA practice: each episode carries a problem id and difficulty tier so training can track per-tier success instead of only aggregate reward.
 
 
 
 
 
 
 
28
 
29
- ## OpenEnv Interface
30
 
31
- The environment uses the latest OpenEnv API shape:
 
 
 
 
32
 
33
- - `AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState])`
34
- - `reset()` returns a typed observation.
35
- - `step(action)` accepts an `AdaptAction` with a Python `code` string.
36
- - `state` exposes episode id, step count, current problem id, difficulty, and recent metrics.
37
 
38
- `openenv.yaml` points to:
39
 
40
- ```yaml
41
- app: server.app:app
42
- port: 7860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ```
44
 
45
- ## Action
 
 
 
 
46
 
47
  ```python
48
- {
49
- "code": "n = int(input())\nprint(n * 2)"
50
- }
51
  ```
52
 
53
- ## Observation
54
 
55
- Reset and step observations include:
 
 
56
 
57
- - problem statement
58
- - input format
59
- - constraints
60
- - examples
61
- - visible tests
62
- - problem id
63
- - difficulty tier
64
- - feedback
65
- - pass rate, visible pass rate, and hidden pass rate
66
- - syntax/runtime/timeout status
67
- - reward components
68
 
69
- Hidden test inputs and expected outputs are never returned in observations.
 
 
70
 
71
- ## Reward
72
 
73
- Reward is clipped to `[0.0, 1.0]` and combines multiple environment-level signals:
 
 
 
74
 
75
- - correctness from visible and hidden pass rate
76
- - syntax validity
77
- - clean execution
78
- - output format compliance
79
- - timeout penalty
80
- - runtime error penalty
81
- - static safety rejection for dangerous imports such as `os`, `subprocess`, `socket`, `pathlib`, and `shutil`
82
 
83
- If `verifier.verifier.verify(code, test_cases)` exists, the environment can use it as an optional reward augmentation. If the verifier is absent, the environment still works using executor-derived reward.
84
 
85
- ## Local Setup
 
 
86
 
87
- Use Python `3.10+`.
88
 
89
- ```powershell
90
- cd C:\Users\kaust\PycharmProjects\meta-rl-dsa-solver
91
- python -m venv .venv
92
- .\.venv\Scripts\pip install -e .
93
- ```
94
 
95
- For this local machine, the existing checked-out OpenEnv repo can also be used during development:
96
 
97
- ```powershell
98
- $env:PYTHONPATH="C:\Users\kaust\PycharmProjects\OpenEnv\src;$PWD"
99
- ```
100
 
101
- ## Smoke Tests
 
 
102
 
103
- Run the local smoke test:
104
 
105
- ```powershell
106
- python test.py
107
  ```
108
 
109
- Check syntax:
110
 
111
- ```powershell
112
- python -m py_compile models.py env\adapt_env.py env\executor.py env\test_cases.py server\app.py
113
  ```
114
 
115
- Start the OpenEnv server:
116
 
117
- ```powershell
118
- uvicorn server.app:app --host 0.0.0.0 --port 7860
119
- ```
 
 
 
 
 
 
 
 
 
 
120
 
121
- Useful endpoints:
122
 
123
- - `GET /`
124
- - `GET /health`
125
- - `GET /metadata`
126
- - `GET /tasks`
127
- - `GET /schema`
128
- - `POST /reset`
129
- - `POST /step`
130
- - `GET /state`
131
- - `POST /mcp`
132
 
133
- Example step request:
 
 
 
 
 
 
134
 
135
  ```powershell
136
- curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"action\":{\"code\":\"n=int(input())\nprint(n*2)\"}}"
137
  ```
138
 
139
- You can also send the raw action body:
140
 
141
  ```powershell
142
- curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d "{\"code\":\"n=int(input())\nprint(n*2)\"}"
143
  ```
144
 
145
- Validate with OpenEnv once dependencies are installed:
146
 
147
  ```powershell
148
- openenv validate .
 
 
149
  ```
150
 
151
- Run the verifier smoke test:
 
 
152
 
153
  ```powershell
154
- python scripts\test_verifier.py
 
 
155
  ```
156
 
157
- Run the environment smoke test:
158
 
159
  ```powershell
160
- python scripts\test_env.py
161
  ```
162
 
163
- Run the baseline model loop:
164
 
165
  ```powershell
166
- $env:HF_TOKEN="..."
167
- $env:API_BASE_URL="https://router.huggingface.co/v1"
168
- $env:MODEL_NAME="openai/gpt-oss-120b"
169
- python inference.py
170
  ```
171
 
172
- Run GRPO training:
173
 
174
  ```powershell
175
- python training\train_grpo.py --output-dir outputs_v2 --bf16
176
  ```
177
 
178
- ## Hugging Face Spaces
179
 
180
- This repo is Docker Space ready:
181
 
182
  ```powershell
183
  openenv push --repo-id <your-hf-username>/adapt-dsa-tutor
184
  ```
185
 
186
- Before final submission, add:
187
-
188
- - live Hugging Face Space link
189
- - training reward/loss plots from Disha's run
190
- - before/after code example showing a problem the model failed before training and solved after training
191
- - mini-blog or short video link
192
-
193
- ## Current Problem Bank
194
 
195
- The environment includes a lightweight curated bank:
 
 
 
 
 
 
196
 
197
- - `easy_double`
198
- - `easy_sum_two`
199
- - `medium_maximum`
200
- - `medium_count_even`
201
- - `hard_reverse_words`
202
 
203
- This is intentionally small for submission-minimum stability. Later work can expand it to 30-50 tiered problems without changing the OpenEnv API.
 
 
 
 
8
  - openenv
9
  - reinforcement-learning
10
  - code-generation
11
+ - llm-training
12
  ---
13
 
14
+ # ADAPT: Adversarial DSA Programming Tutor
15
 
16
+ LLMs are getting better at one-shot code generation, but they still struggle with what real engineers do all day: read feedback, debug, and repair. ADAPT targets that gap by turning algorithm practice into a self-repair RL environment where the model must improve over multiple attempts instead of guessing once.
17
 
18
+ ## Why ADAPT exists
19
 
20
+ Most code-generation benchmarks test whether a model can land the answer immediately. They do not test whether the model can recover from partial failure, use examples productively, or adapt as the task distribution changes.
21
 
22
+ ADAPT is built to stress exactly those capabilities:
23
+
24
+ - adaptive difficulty across easy, medium, and hard DSA families
25
+ - visible examples plus hidden evaluation tests
26
+ - multi-step repair with feedback between attempts
27
+ - reward-aware problem generation that shifts toward the most educational families
28
+
29
+ ## Architecture
30
 
31
  ```text
32
+ +------------+ +-----------+ +----------+ +-----------+
33
+ | Generator | --> | Problem | --> | Solver | --> | Execution |
34
+ +------------+ +-----------+ +----------+ +-----------+
35
+ ^ |
36
+ | v
37
+ +------------- Curriculum <- Reward <- Verification -----+
38
  ```
39
 
40
+ ## What the agent sees, does, and gets rewarded for
41
+
42
+ The agent sees a plain-English programming problem, the stdin format, constraints, and two worked examples. It writes Python code that reads from stdin and prints to stdout.
43
+
44
+ The environment executes that code on 10 tests per problem:
45
+
46
+ - 2 visible tests shown as examples
47
+ - 8 hidden tests used for the real pass-rate reward
48
 
49
+ After each attempt, the environment returns:
50
 
51
+ - hidden pass rate
52
+ - visible pass rate
53
+ - execution status such as `completed`, `wrong_answer`, `runtime_error`, or `timeout`
54
+ - a compact list of which tests failed
55
+ - enough context to try again on the same problem
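To make the two pass rates concrete, here is a minimal sketch of how they could be derived from per-test results. The `visible` and `passed` keys are assumptions made for illustration; the environment's actual result schema is not shown in this README.

```python
def pass_rates(results):
    """Return (visible_pass_rate, hidden_pass_rate) for a list of per-test results."""

    def rate(group):
        # Fraction of tests in the group that passed; 0.0 for an empty group.
        return sum(1 for r in group if r["passed"]) / len(group) if group else 0.0

    visible = [r for r in results if r["visible"]]
    hidden = [r for r in results if not r["visible"]]
    return rate(visible), rate(hidden)


# 2 visible + 8 hidden tests: one visible test passes, two hidden tests pass.
results = (
    [{"visible": True, "passed": p} for p in (True, False)]
    + [{"visible": False, "passed": p}
       for p in (False, True, True, False, False, False, False, False)]
)
visible_rate, hidden_rate = pass_rates(results)
```

Only the hidden rate feeds the reward; the visible rate serves as feedback for the next attempt.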
56
 
57
+ ## Multi-step repair loop
 
 
 
58
 
59
+ Each episode allows up to 3 attempts on the same problem.
60
 
61
+ 1. Attempt 1: the agent submits a first solution.
62
+ 2. Feedback: ADAPT reports the current execution status, hidden pass rate, visible pass rate, and which visible/hidden tests failed.
63
+ 3. Attempt 2 or 3: the agent repairs its code using that feedback.
64
+ 4. The episode ends early if all hidden tests pass.
65
+
66
+ Concrete example:
67
+
68
+ ```text
69
+ Problem family: running_total
70
+
71
+ Attempt 1 code:
72
+ print(sum(nums))
73
+
74
+ Feedback:
75
+ Attempt 1/3
76
+ Previous attempt status: ready
77
+ Current execution status: wrong_answer
78
+ Hidden pass rate: 0.25
79
+ Visible pass rate: 0.50
80
+ Failed tests:
81
+ - Visible test #2: wrong_answer (expected=5 3 10, got=10)
82
+ - Hidden test #1: wrong_answer
83
+ - Hidden test #4: wrong_answer
84
+
85
+ Attempt 2 code:
+ out = []
+ running = 0
+ for x in nums:
+     running += x
+     out.append(str(running))
+ print(" ".join(out))
91
  ```
92
 
93
+ That repair loop is the core novelty of ADAPT: the model is rewarded for debugging, not just for lucky first drafts.
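The loop can be sketched in a few lines of Python. This is an illustration under assumptions, not the environment's own driver: `step_fn` stands in for a call to `/step`, `propose_code` stands in for the model, and the observation is assumed to be a dict with `feedback`, `done`, and `hidden_pass_rate` keys.

```python
def run_repair_episode(step_fn, propose_code, max_attempts=3):
    """Propose code, read feedback, and retry until done or out of attempts."""
    feedback = ""
    observation = {}
    attempts_used = 0
    for attempt in range(1, max_attempts + 1):
        code = propose_code(feedback, attempt)
        observation = step_fn(code)
        attempts_used = attempt
        feedback = observation.get("feedback", "")
        if observation.get("done"):
            break  # solved, or the environment closed the episode
    return attempts_used, observation


# Stand-in environment and policy so the sketch runs without a server.
def fake_step(code):
    solved = "running" in code
    return {
        "hidden_pass_rate": 1.0 if solved else 0.25,
        "feedback": "" if solved else "Visible test #2: wrong_answer",
        "done": solved,
    }


def fake_policy(feedback, attempt):
    # First draft is the buggy one-liner; later drafts pretend to be repaired.
    return "print(sum(nums))" if attempt == 1 else "running = 0  # repaired draft"


attempts, final = run_repair_episode(fake_step, fake_policy)
```

With the stand-ins above, the episode stops after the second attempt because the fake environment reports `done`.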
94
+
95
+ ## Reward function
96
+
97
+ ADAPT uses a clean reward signal driven by hidden correctness:
98
 
99
  ```python
100
+ reward = hidden_pass_rate * step_discount
 
 
101
  ```
102
 
103
+ Where:
104
 
105
+ - `step_discount = 1.00` on attempt 1
106
+ - `step_discount = 0.85` on attempt 2
107
+ - `step_discount = 0.70` on attempt 3
108
 
109
+ Additional shaping for the repair loop:
 
 
 
 
 
 
 
 
 
 
110
 
111
+ - if a failed non-terminal attempt improves hidden pass rate, reward = `0.1 * delta_pass_rate`
112
+ - if the final attempt still fails, reward = `0.0`
113
+ - timeouts and syntax errors always get `0.0`
114
 
115
+ Examples:
116
 
117
+ - attempt 1 solves all 8 hidden tests: reward = `1.0`
118
+ - attempt 2 solves all 8 hidden tests: reward = `0.85`
119
+ - attempt 2 still fails but improves hidden pass rate from `0.25` to `0.50`: reward = `0.1 * 0.25 = 0.025`
120
+ - attempt 3 still fails: reward = `0.0`
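One possible reading of these rules, written out as a function. The README does not say what a failed attempt that does not improve should earn, so this sketch assumes `0.0`, and it treats "fails" as "hidden pass rate below 1.0". The function name and signature are hypothetical; the real logic lives in `verifier/metrics.py`, which may differ.

```python
STEP_DISCOUNT = {1: 1.00, 2: 0.85, 3: 0.70}
MAX_ATTEMPTS = 3


def sketch_reward(hidden_pass_rate, attempt, previous_pass_rate=0.0, status="completed"):
    """Hypothetical reward: discounted success, small credit for improvement."""
    if status in {"timeout", "syntax_error"}:
        return 0.0  # timeouts and syntax errors never earn reward
    if hidden_pass_rate >= 1.0:
        return hidden_pass_rate * STEP_DISCOUNT[attempt]  # solved on this attempt
    if attempt >= MAX_ATTEMPTS:
        return 0.0  # final attempt still failing
    delta = hidden_pass_rate - previous_pass_rate
    # Assumption: a failed attempt that does not improve earns nothing.
    return 0.1 * delta if delta > 0 else 0.0
```

For instance, `sketch_reward(0.50, attempt=2, previous_pass_rate=0.25)` gives the `0.025` shaping reward, while the same pass rates on attempt 3 give `0.0`.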
121
 
122
+ ## Problem families
 
 
 
 
 
 
123
 
124
+ ADAPT now covers 20 algorithmic families instead of a tiny fixed bank:
125
 
126
+ - Easy: `sum_even_numbers`, `range_span`, `count_vowels`, `max_consecutive_ones`, `fizzbuzz_variant`, `running_total`
127
+ - Medium: `count_local_peaks`, `longest_non_decreasing_run`, `two_sum_count`, `max_subarray_sum`, `group_anagrams_count`, `balanced_brackets`, `matrix_diagonal_sum`
128
+ - Hard: `smallest_most_frequent`, `reverse_words`, `longest_common_subsequence`, `word_ladder_steps`, `merge_intervals`, `min_coins`, `rotate_matrix_90`
129
 
130
+ Every family has:
131
 
132
+ - its own randomized case generator
133
+ - 2 visible example tests
134
+ - 8 hidden evaluation tests
135
+ - a reference solver that auto-generates expected outputs
 
136
 
137
+ ## Self-improving curriculum
138
 
139
+ In training, a single component makes all curriculum decisions: the `CurriculumManager` inside `training/train_grpo.py`. Its settings:
 
 
140
 
141
+ - promote threshold: `0.70`
142
+ - demote threshold: `0.30`
143
+ - moving-average window: `10` episodes
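A minimal sketch of how thresholds like these could drive tier changes. `CurriculumSketch` is a hypothetical stand-in, not the `CurriculumManager` from `training/train_grpo.py`; waiting for a full window and clearing it after each tier change are assumptions.

```python
from collections import deque


class CurriculumSketch:
    """Hypothetical tier controller: 1 = easy, 2 = medium, 3 = hard."""

    def __init__(self, promote=0.70, demote=0.30, window=10, min_tier=1, max_tier=3):
        self.promote = promote
        self.demote = demote
        self.min_tier = min_tier
        self.max_tier = max_tier
        self.tier = min_tier
        self.recent = deque(maxlen=window)

    def record(self, pass_rate):
        """Add one episode's pass rate and return the (possibly updated) tier."""
        self.recent.append(float(pass_rate))
        if len(self.recent) < self.recent.maxlen:
            return self.tier  # assumption: wait for a full window before moving
        average = sum(self.recent) / len(self.recent)
        if average >= self.promote and self.tier < self.max_tier:
            self.tier += 1
            self.recent.clear()  # assumption: start a fresh window after a change
        elif average <= self.demote and self.tier > self.min_tier:
            self.tier -= 1
            self.recent.clear()
        return self.tier


curriculum = CurriculumSketch()
for _ in range(10):
    curriculum.record(0.9)  # ten strong episodes in a row
tier_after_strong_run = curriculum.tier
```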
144
 
145
+ On top of that, the generator tracks `family_productivity`, an exponential moving average (EMA) of how educational each family is:
146
 
147
+ ```text
148
+ family_productivity[family] = 0.9 * old + 0.1 * generator_reward
149
  ```
150
 
151
+ Families that produce pass rates near the learning sweet spot, around `0.5`, become more likely to be sampled via a softmax distribution. This creates a closed loop:
152
 
153
+ ```text
154
+ productive families -> more samples -> better learning signal -> updated family productivity
155
  ```
156
 
157
+ That makes ADAPT more than a static benchmark. The environment actively searches for the problems that teach the model the most.
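The EMA update and softmax sampling described in this section can be sketched as follows. The `0.9`/`0.1` weights come from the formula above; the starting value for unseen families, the `temperature` parameter, and the function names are assumptions and may not match `env/generator.py`.

```python
import math
import random


def update_productivity(productivity, family, generator_reward, alpha=0.1):
    """EMA update: new = (1 - alpha) * old + alpha * generator_reward."""
    old = productivity.get(family, 0.0)  # assumption: unseen families start at 0.0
    productivity[family] = (1 - alpha) * old + alpha * generator_reward
    return productivity[family]


def softmax_weights(productivity, temperature=1.0):
    """Turn productivity scores into sampling probabilities."""
    top = max(productivity.values())  # subtract the max for numerical stability
    exps = {k: math.exp((v - top) / temperature) for k, v in productivity.items()}
    total = sum(exps.values())
    return {k: e / total for k, e in exps.items()}


scores = {"running_total": 0.2, "two_sum_count": 0.8}
updated = update_productivity(scores, "running_total", 1.0)
weights = softmax_weights(scores)
# Draw the next family in proportion to its softmax weight.
next_family = random.choices(list(weights), weights=list(weights.values()), k=1)[0]
```

A lower `temperature` concentrates sampling on the most productive families; a higher one keeps the distribution closer to uniform.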
158
 
159
+ ## Results
160
+
161
+ [INSERT: reward curve plot]
162
+
163
+ [INSERT: baseline vs trained table]
164
+
165
+ Recommended artifacts to include here:
166
+
167
+ - reward curve from `training/reward_curve.csv`
168
+ - `reward_curve.png`
169
+ - `pass_rate_by_difficulty.png`
170
+ - `family_productivity.png`
171
+ - one before/after repair example from baseline vs trained evaluation
172
 
173
+ ## How to run
174
 
175
+ ### 1. Install dependencies
 
 
 
 
 
 
 
 
176
 
177
+ ```powershell
178
+ cd C:\Users\kaust\PycharmProjects\meta-rl-dsa-solver
179
+ python -m venv .venv
180
+ .\.venv\Scripts\pip install -e .
181
+ ```
182
+
183
+ For training and plotting, also install your training extras:
184
 
185
  ```powershell
186
+ .\.venv\Scripts\pip install trl unsloth matplotlib wandb
187
  ```
188
 
189
+ ### 2. Start the OpenEnv server
190
 
191
  ```powershell
192
+ python server\app.py
193
  ```
194
 
195
+ ### 3. Reset an environment session
196
 
197
  ```powershell
198
+ curl -X POST http://localhost:7860/reset ^
199
+ -H "Content-Type: application/json" ^
200
+ -d "{\"difficulty\":\"easy\"}"
201
  ```
202
 
203
+ The response includes a `session_id`. Reuse it for `step` and `state`.
204
+
205
+ ### 4. Submit code to `/step`
206
 
207
  ```powershell
208
+ curl -X POST http://localhost:7860/step ^
209
+ -H "Content-Type: application/json" ^
210
+ -d "{\"session_id\":\"<SESSION_ID>\",\"code\":\"n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))\"}"
211
  ```
212
 
213
+ ### 5. Inspect current state
214
 
215
  ```powershell
216
+ curl "http://localhost:7860/state?session_id=<SESSION_ID>"
217
  ```
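For a shell-independent alternative to the `curl` calls in steps 3 to 5, the same three requests can be built with Python's standard library. This is a sketch that assumes the request shapes shown above (`difficulty` for `/reset`, `session_id` plus `code` for `/step`, `session_id` as a query parameter for `/state`); the helper names are hypothetical.

```python
import json
import urllib.request
from urllib.parse import urlencode

BASE_URL = "http://localhost:7860"


def _post(path, payload, base_url=BASE_URL):
    """Build a JSON POST request without sending it."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}{path}",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def build_reset_request(difficulty="easy", base_url=BASE_URL):
    return _post("/reset", {"difficulty": difficulty}, base_url)


def build_step_request(session_id, code, base_url=BASE_URL):
    return _post("/step", {"session_id": session_id, "code": code}, base_url)


def build_state_request(session_id, base_url=BASE_URL):
    query = urlencode({"session_id": session_id})
    return urllib.request.Request(f"{base_url}/state?{query}", method="GET")


def send(request):
    """Send a prepared request and decode the JSON response (needs a running server)."""
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, `send(build_reset_request())` should return a payload that includes `session_id`, which can then be passed to the other two builders.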
218
 
219
+ ### 6. Run training
220
 
221
  ```powershell
222
+ python training\train_grpo.py ^
223
+ --generator-mode reward_aware ^
224
+ --baseline-eval ^
225
+ --output-dir outputs_v3
226
  ```
227
 
228
+ ### 7. Plot the training curves
229
 
230
  ```powershell
231
+ python training\plot_results.py outputs_v3\reward_curve.csv
232
  ```
233
 
234
+ ## Hugging Face Space
235
 
236
+ This repo is designed to be hosted as an OpenEnv FastAPI Space.
237
 
238
  ```powershell
239
  openenv push --repo-id <your-hf-username>/adapt-dsa-tutor
240
  ```
241
 
242
+ ## Submission checklist
 
 
 
 
 
 
 
243
 
244
+ - OpenEnv environment with `Environment`, `reset`, `step`, and `state`
245
+ - valid `openenv.yaml`
246
+ - Hugging Face Space deployment
247
+ - GRPO training script with Unsloth + TRL
248
+ - reward and pass-rate plots from a real run
249
+ - baseline vs trained evaluation summary
250
+ - Colab notebook link for reproducibility
251
 
252
+ ## Links
 
 
 
 
253
 
254
+ - HuggingFace Space URL: [HuggingFace Space URL]
255
+ - Colab Training Notebook: [Colab Training Notebook]
256
+ - HF Blog Post: [HF Blog Post]
257
+ - YouTube Demo: [YouTube Demo]
client.py CHANGED
@@ -11,6 +11,7 @@ class AdaptEnvClient:
     def __init__(self, base_url: str = "http://localhost:7860") -> None:
         self.base_url = base_url.rstrip("/")
         self._client = httpx.Client(base_url=self.base_url, timeout=30.0)
+        self.session_id: str | None = None
 
     def close(self) -> None:
         self._client.close()
@@ -18,15 +19,24 @@ class AdaptEnvClient:
     def reset(self, **params: Any) -> dict[str, Any]:
         response = self._client.post("/reset", json=params)
         response.raise_for_status()
-        return response.json()
+        payload = response.json()
+        self.session_id = payload.get("session_id")
+        return payload
 
     def step(self, code: str) -> dict[str, Any]:
-        response = self._client.post("/step", json=AdaptAction(code=code).model_dump())
+        if not self.session_id:
+            raise RuntimeError("Call reset() before step() so the client has a session_id.")
+        response = self._client.post(
+            "/step",
+            json=AdaptAction(session_id=self.session_id, code=code).model_dump(),
+        )
         response.raise_for_status()
         return response.json()
 
     def state(self) -> dict[str, Any]:
-        response = self._client.get("/state")
+        if not self.session_id:
+            raise RuntimeError("Call reset() before state() so the client has a session_id.")
+        response = self._client.get("/state", params={"session_id": self.session_id})
         response.raise_for_status()
         return response.json()
 
env/adapt_env.py CHANGED
@@ -6,6 +6,7 @@ from uuid import uuid4
6
 
7
  from env.generator import DIFFICULTY_LABELS, GeneratorAgent, generator_reward, validate_problem
8
  from models import AdaptAction, AdaptObservation, AdaptState
 
9
 
10
  try:
11
  from openenv.core.env_server.interfaces import Environment
@@ -22,6 +23,7 @@ except ImportError:
22
 
23
 
24
  FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
 
25
 
26
 
27
  class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
@@ -31,14 +33,16 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
31
  self,
32
  generator: GeneratorAgent | None = None,
33
  generator_mode: str = "heuristic",
 
34
  ) -> None:
35
  super().__init__()
36
  self.generator = generator or GeneratorAgent()
37
  self.generator_mode = generator_mode
 
38
  self.problem: dict[str, Any] = {}
39
- self.test_cases: list[dict[str, str]] = []
40
  self.last_results: list[dict[str, Any]] = []
41
- self.max_history = 20
42
  self.min_difficulty = 1
43
  self.max_difficulty = 3
44
  self.difficulty: int = 1
@@ -49,7 +53,17 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
49
  "problem_signatures": [],
50
  "episode_index": 0,
51
  }
52
- self._state = AdaptState(episode_id=str(uuid4()), step_count=0, generator_mode=self.generator_mode)
 
 
 
 
 
 
 
 
 
 
53
 
54
  def reset(
55
  self,
@@ -59,35 +73,52 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
59
  difficulty: str | None = None,
60
  generated_problem: dict[str, Any] | None = None,
61
  generator_mode: str | None = None,
 
 
62
  **_: Any,
63
  ) -> AdaptObservation:
64
  del seed
 
 
 
65
  if generator_mode is not None:
66
  self.generator_mode = generator_mode
67
  if difficulty is not None:
68
  self.difficulty = self._difficulty_to_tier(difficulty)
69
- elif self.history["recent_pass_rates"]:
70
- self.difficulty = self._recommend_next_difficulty()
 
 
71
 
72
  self.problem = self._load_problem(
73
  generated_problem=generated_problem,
74
  problem_id=problem_id,
 
75
  )
76
  self.test_cases = [dict(test_case) for test_case in self.problem["test_cases"]]
77
  self.last_results = []
 
 
 
78
  self._state = AdaptState(
 
79
  episode_id=episode_id or str(uuid4()),
80
  step_count=0,
81
  problem_id=self.problem["problem_id"],
82
  problem_type=self.problem.get("problem_type", ""),
83
  difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
84
  generator_mode=self.generator_mode,
 
85
  generated_problem=self._public_problem_view(),
 
86
  )
87
  return self._build_observation(
88
  reward=0.0,
89
  done=False,
90
- feedback="Submit Python code that reads stdin and prints the required answer.",
 
 
 
91
  execution_status="ready",
92
  )
93
 
@@ -98,59 +129,130 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
98
  **_: Any,
99
  ) -> AdaptObservation:
100
  del timeout_s
 
101
  if not self.problem:
102
- self.reset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  self._state.step_count += 1
 
 
 
 
105
  syntax_ok, syntax_error = self._check_syntax(action.code)
106
  if not syntax_ok:
 
107
  observation = self._build_observation(
108
  reward=0.0,
109
- done=True,
110
- feedback=f"Syntax error: {syntax_error}",
 
 
 
 
 
111
  syntax_valid=False,
112
  execution_status="syntax_error",
113
- reward_components={"correctness": 0.0, "format": 0.0},
 
 
 
 
114
  )
115
- self._finalize_episode(observation)
 
 
 
 
116
  return observation
117
 
118
  safety_ok, safety_error = self._check_safety(action.code)
119
  if not safety_ok:
 
120
  observation = self._build_observation(
121
  reward=0.0,
122
- done=True,
123
- feedback=safety_error,
 
 
 
 
 
124
  syntax_valid=True,
125
  execution_status="safety_violation",
126
- reward_components={"correctness": 0.0, "format": 0.0},
 
 
 
 
127
  )
128
- self._finalize_episode(observation)
 
 
 
 
129
  return observation
130
 
131
- reward, metadata = self._verify_submission(action.code)
132
  self.last_results = list(metadata.get("results", []))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  observation = self._build_observation(
134
  reward=reward,
135
- done=True,
136
- feedback=str(metadata.get("feedback", "Evaluation complete.")),
137
- pass_rate=float(metadata.get("pass_rate", 0.0)),
138
- visible_pass_rate=0.0,
139
- hidden_pass_rate=float(metadata.get("pass_rate", 0.0)),
140
  syntax_valid=True,
141
- execution_status=str(metadata.get("execution_status", "completed")),
142
  timeout_count=int(metadata.get("timeout_count", 0)),
143
  runtime_error_count=int(metadata.get("runtime_error_count", 0)),
144
  invalid_output_count=int(metadata.get("invalid_output_count", 0)),
145
  wrong_answer_count=int(metadata.get("wrong_answer_count", 0)),
146
  format_compliance=float(metadata.get("format_compliance", 0.0)),
147
- reward_components={
148
- key: round(float(value), 4)
149
- for key, value in dict(metadata.get("reward_components", {})).items()
150
- },
151
  generator_reward_signal=float(metadata.get("generator_reward", 0.0)),
152
  )
153
- self._finalize_episode(observation)
 
 
 
154
  return observation
155
 
156
  @property
@@ -177,23 +279,26 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
177
  ) -> AdaptObservation:
178
  public_problem = self._public_problem_view()
179
  return AdaptObservation(
 
180
  problem_id=self.problem.get("problem_id", ""),
181
  problem_type=self.problem.get("problem_type", ""),
182
  difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
 
 
183
  problem=public_problem.get("problem", ""),
184
  input_format=public_problem.get("input_format", ""),
185
  constraints=public_problem.get("constraints", ""),
186
  feedback=feedback,
187
- pass_rate=pass_rate,
188
- visible_pass_rate=visible_pass_rate,
189
- hidden_pass_rate=hidden_pass_rate,
190
  syntax_valid=syntax_valid,
191
  execution_status=execution_status,
192
  timeout_count=timeout_count,
193
  runtime_error_count=runtime_error_count,
194
  invalid_output_count=invalid_output_count,
195
  wrong_answer_count=wrong_answer_count,
196
- format_compliance=format_compliance,
197
  reward_components=reward_components or {},
198
  generator_reward_signal=round(float(generator_reward_signal), 4),
199
  reward=round(max(0.0, min(1.0, reward)), 4),
@@ -204,15 +309,22 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
204
  self,
205
  generated_problem: dict[str, Any] | None,
206
  problem_id: str | None,
 
207
  ) -> dict[str, Any]:
208
- candidate = generated_problem or self.generator.generate(
209
  self.difficulty,
210
  self.history,
211
  problem_id=problem_id,
 
212
  )
213
  if validate_problem(candidate):
214
  return candidate
215
- fallback = self.generator.generate(self.difficulty, self.history, problem_id=problem_id)
 
 
 
 
 
216
  if not validate_problem(fallback):
217
  raise ValueError("Generator produced an invalid problem twice in a row.")
218
  return fallback
@@ -221,53 +333,155 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
221
  try:
222
  from verifier.verifier import verify
223
  except ImportError as exc:
224
- return 0.0, {"feedback": f"Verifier unavailable: {exc}", "execution_status": "verifier_error"}
 
 
 
 
225
 
226
  try:
227
  reward, metadata = verify(code, self.test_cases)
228
  except Exception as exc:
229
- return 0.0, {"feedback": f"Verifier crashed: {exc}", "execution_status": "verifier_error"}
 
 
 
 
230
 
231
  metadata = dict(metadata or {})
232
  diversity_bonus = self._diversity_bonus(self.problem.get("problem_type", ""))
233
  validity_bonus = float(self.problem.get("validity_bonus", 0.0))
 
234
  metadata["generator_reward"] = generator_reward(
235
- float(metadata.get("pass_rate", 0.0)),
236
  diversity_bonus=diversity_bonus,
237
  validity_bonus=validity_bonus,
238
  )
239
  return float(reward), metadata
240
 
241
- def _finalize_episode(self, observation: AdaptObservation) -> None:
242
- self._update_history(observation.pass_rate, observation.generator_reward_signal)
243
- self._record_metrics(observation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
- def _update_history(self, pass_rate: float, generator_signal: float) -> None:
246
- self.history["recent_pass_rates"].append(round(float(pass_rate), 4))
247
- self.history["problem_types"].append(self.problem.get("problem_type", ""))
248
- self.history["problem_signatures"].append(self.problem.get("problem_id", ""))
249
- self.history["generator_rewards"].append(round(float(generator_signal), 4))
250
- self.history["episode_index"] = int(self.history.get("episode_index", 0)) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
253
- values = self.history[key]
254
- if len(values) > self.max_history:
255
- del values[:-self.max_history]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
  def _record_metrics(self, observation: AdaptObservation) -> None:
 
 
 
 
 
 
 
 
 
 
258
  self._state.last_reward = float(observation.reward or 0.0)
259
  self._state.last_pass_rate = observation.pass_rate
260
  self._state.last_feedback = observation.feedback
 
261
  self._state.generator_reward_signal = observation.generator_reward_signal
262
- self._state.history = {
263
- "recent_pass_rates": list(self.history["recent_pass_rates"]),
264
- "problem_types": list(self.history["problem_types"]),
265
- "generator_rewards": list(self.history["generator_rewards"]),
266
- }
267
  self._state.recent_metrics = {
268
  "difficulty_tier": self.difficulty,
269
  "difficulty_label": self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
270
- "history_size": len(self.history["recent_pass_rates"]),
271
  "pass_rate": observation.pass_rate,
272
  "execution_status": observation.execution_status,
273
  "timeout_count": observation.timeout_count,
@@ -278,27 +492,47 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
278
  "reward_components": dict(observation.reward_components),
279
  }
280
 
281
- def _recommend_next_difficulty(self) -> int:
282
- recent = [float(value) for value in self.history["recent_pass_rates"][-5:]]
283
- if not recent:
284
- return self.difficulty
285
- moving_average = sum(recent) / len(recent)
286
- if moving_average > 0.75:
287
- return min(self.max_difficulty, self.difficulty + 1)
288
- if moving_average < 0.25:
289
- return max(self.min_difficulty, self.difficulty - 1)
290
- return self.difficulty
 
 
 
 
 
291
 
292
  def _public_problem_view(self) -> dict[str, str]:
293
  visible = dict(self.problem.get("visible_problem", {}))
 
 
 
 
294
  return {
295
- "problem": visible.get("problem", self.problem.get("problem", "")),
296
  "input_format": visible.get("input_format", self.problem.get("input_format", "")),
297
  "constraints": visible.get("constraints", self.problem.get("constraints", "")),
298
  }
299
 
 
 
 
 
 
 
 
 
 
 
 
300
  def _diversity_bonus(self, problem_type: str) -> float:
301
- recent_types = list(self.history.get("problem_types", [])[-4:])
302
  if not recent_types:
303
  return 0.1
304
  if problem_type in recent_types:
 
6
 
7
  from env.generator import DIFFICULTY_LABELS, GeneratorAgent, generator_reward, validate_problem
8
  from models import AdaptAction, AdaptObservation, AdaptState
9
+ from verifier.metrics import compute_reward
10
 
11
  try:
12
  from openenv.core.env_server.interfaces import Environment
 
23
 
24
 
25
  FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
26
+ MAX_STEPS_PER_EPISODE = 3
27
 
28
 
29
  class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
 
33
  self,
34
  generator: GeneratorAgent | None = None,
35
  generator_mode: str = "heuristic",
36
+ session_id: str | None = None,
37
  ) -> None:
38
  super().__init__()
39
  self.generator = generator or GeneratorAgent()
40
  self.generator_mode = generator_mode
41
+ self.session_id = session_id or str(uuid4())
42
  self.problem: dict[str, Any] = {}
43
+ self.test_cases: list[dict[str, Any]] = []
44
  self.last_results: list[dict[str, Any]] = []
45
+ self.max_history = 50
46
  self.min_difficulty = 1
47
  self.max_difficulty = 3
48
  self.difficulty: int = 1
 
53
  "problem_signatures": [],
54
  "episode_index": 0,
55
  }
56
+ self.attempt_history: list[dict[str, Any]] = []
57
+ self.previous_execution_status = "ready"
58
+ self.episode_done = False
59
+ self._state = AdaptState(
60
+ session_id=self.session_id,
61
+ episode_id=str(uuid4()),
62
+ step_count=0,
63
+ generator_mode=self.generator_mode,
64
+ max_steps=MAX_STEPS_PER_EPISODE,
65
+ history={"attempts": []},
66
+ )
67
 
68
  def reset(
69
  self,
 
73
  difficulty: str | None = None,
74
  generated_problem: dict[str, Any] | None = None,
75
  generator_mode: str | None = None,
76
+ session_id: str | None = None,
77
+ family_weights: dict[str, float] | None = None,
78
  **_: Any,
79
  ) -> AdaptObservation:
80
  del seed
81
+
82
+ if session_id:
83
+ self.session_id = session_id
84
  if generator_mode is not None:
85
  self.generator_mode = generator_mode
86
  if difficulty is not None:
87
  self.difficulty = self._difficulty_to_tier(difficulty)
88
+ elif generated_problem is not None:
89
+ generated_label = str(generated_problem.get("difficulty_label", "")).strip().lower()
90
+ if generated_label:
91
+ self.difficulty = self._difficulty_to_tier(generated_label)
92
 
93
  self.problem = self._load_problem(
94
  generated_problem=generated_problem,
95
  problem_id=problem_id,
96
+ family_weights=family_weights,
97
  )
98
  self.test_cases = [dict(test_case) for test_case in self.problem["test_cases"]]
99
  self.last_results = []
100
+ self.attempt_history = []
101
+ self.previous_execution_status = "ready"
102
+ self.episode_done = False
103
  self._state = AdaptState(
104
+ session_id=self.session_id,
105
  episode_id=episode_id or str(uuid4()),
106
  step_count=0,
107
  problem_id=self.problem["problem_id"],
108
  problem_type=self.problem.get("problem_type", ""),
109
  difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
110
  generator_mode=self.generator_mode,
111
+ max_steps=MAX_STEPS_PER_EPISODE,
112
  generated_problem=self._public_problem_view(),
113
+ history={"attempts": []},
114
  )
115
  return self._build_observation(
116
  reward=0.0,
117
  done=False,
118
+ feedback=(
119
+ "You have up to 3 attempts. Submit Python code that reads stdin and prints the required answer. "
120
+ "Use the examples to infer the expected behavior."
121
+ ),
122
  execution_status="ready",
123
  )
124
 
 
129
  **_: Any,
130
  ) -> AdaptObservation:
131
  del timeout_s
132
+
133
  if not self.problem:
134
+ self.reset(session_id=action.session_id or self.session_id)
135
+
136
+ if self.episode_done:
137
+ return self._build_observation(
138
+ reward=float(self._state.last_reward or 0.0),
139
+ done=True,
140
+ feedback="This episode is finished. Call reset() to start a new problem.",
141
+ pass_rate=float(self._state.last_pass_rate or 0.0),
142
+ visible_pass_rate=float(self._state.recent_metrics.get("visible_pass_rate", 0.0)),
143
+ hidden_pass_rate=float(self._state.last_pass_rate or 0.0),
144
+ syntax_valid=self._state.last_execution_status != "syntax_error",
145
+ execution_status=self._state.last_execution_status or "completed",
146
+ timeout_count=int(self._state.recent_metrics.get("timeout_count", 0)),
147
+ runtime_error_count=int(self._state.recent_metrics.get("runtime_error_count", 0)),
148
+ invalid_output_count=int(self._state.recent_metrics.get("invalid_output_count", 0)),
149
+ wrong_answer_count=int(self._state.recent_metrics.get("wrong_answer_count", 0)),
150
+ format_compliance=float(self._state.recent_metrics.get("format_compliance", 0.0)),
151
+ reward_components=dict(self._state.recent_metrics.get("reward_components", {})),
152
+ generator_reward_signal=float(self._state.generator_reward_signal or 0.0),
153
+ )
154
 
155
  self._state.step_count += 1
156
+ attempt_number = self._state.step_count
157
+ previous_status = self.previous_execution_status
158
+ previous_pass_rate = float(self._state.last_pass_rate or 0.0)
159
+
160
  syntax_ok, syntax_error = self._check_syntax(action.code)
161
  if not syntax_ok:
162
+ done = attempt_number >= MAX_STEPS_PER_EPISODE
163
  observation = self._build_observation(
164
  reward=0.0,
165
+ done=done,
166
+ feedback=self._format_static_feedback(
167
+ attempt_number=attempt_number,
168
+ previous_status=previous_status,
169
+ execution_status="syntax_error",
170
+ details=f"Syntax error: {syntax_error}",
171
+ ),
172
  syntax_valid=False,
173
  execution_status="syntax_error",
174
+ reward_components={
175
+ "correctness": 0.0,
176
+ "step_discount": 1.0 if attempt_number == 1 else (0.85 if attempt_number == 2 else 0.70),
177
+ "progress_delta": 0.0,
178
+ },
179
  )
180
+ self.last_results = []
181
+ self.previous_execution_status = observation.execution_status
182
+ self._record_metrics(observation)
183
+ if done:
184
+ self._finalize_episode(observation)
185
  return observation
186
 
187
  safety_ok, safety_error = self._check_safety(action.code)
188
  if not safety_ok:
189
+ done = attempt_number >= MAX_STEPS_PER_EPISODE
190
  observation = self._build_observation(
191
  reward=0.0,
192
+ done=done,
193
+ feedback=self._format_static_feedback(
194
+ attempt_number=attempt_number,
195
+ previous_status=previous_status,
196
+ execution_status="safety_violation",
197
+ details=safety_error,
198
+ ),
199
  syntax_valid=True,
200
  execution_status="safety_violation",
201
+ reward_components={
202
+ "correctness": 0.0,
203
+ "step_discount": 1.0 if attempt_number == 1 else (0.85 if attempt_number == 2 else 0.70),
204
+ "progress_delta": 0.0,
205
+ },
206
  )
207
+ self.last_results = []
208
+ self.previous_execution_status = observation.execution_status
209
+ self._record_metrics(observation)
210
+ if done:
211
+ self._finalize_episode(observation)
212
  return observation
213
 
214
+ _, metadata = self._verify_submission(action.code)
215
  self.last_results = list(metadata.get("results", []))
216
+ hidden_pass_rate = float(metadata.get("hidden_pass_rate", metadata.get("pass_rate", 0.0)))
217
+ visible_pass_rate = float(metadata.get("visible_pass_rate", 0.0))
218
+ execution_status = str(metadata.get("execution_status", "completed"))
219
+ done = hidden_pass_rate == 1.0 or attempt_number >= MAX_STEPS_PER_EPISODE
220
+ reward, reward_components = self._shape_reward(
221
+ pass_rate=hidden_pass_rate,
222
+ step_number=attempt_number,
223
+ execution_status=execution_status,
224
+ previous_pass_rate=previous_pass_rate,
225
+ done=done,
226
+ )
227
+ feedback = self._format_feedback(
228
+ results=self.last_results,
229
+ attempt_number=attempt_number,
230
+ previous_status=previous_status,
231
+ execution_status=execution_status,
232
+ hidden_pass_rate=hidden_pass_rate,
233
+ visible_pass_rate=visible_pass_rate,
234
+ )
235
  observation = self._build_observation(
236
  reward=reward,
237
+ done=done,
238
+ feedback=feedback,
239
+ pass_rate=hidden_pass_rate,
240
+ visible_pass_rate=visible_pass_rate,
241
+ hidden_pass_rate=hidden_pass_rate,
242
  syntax_valid=True,
243
+ execution_status=execution_status,
244
  timeout_count=int(metadata.get("timeout_count", 0)),
245
  runtime_error_count=int(metadata.get("runtime_error_count", 0)),
246
  invalid_output_count=int(metadata.get("invalid_output_count", 0)),
247
  wrong_answer_count=int(metadata.get("wrong_answer_count", 0)),
248
  format_compliance=float(metadata.get("format_compliance", 0.0)),
249
+ reward_components=reward_components,
 
 
 
250
  generator_reward_signal=float(metadata.get("generator_reward", 0.0)),
251
  )
252
+ self.previous_execution_status = observation.execution_status
253
+ self._record_metrics(observation)
254
+ if done:
255
+ self._finalize_episode(observation)
256
  return observation
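The termination rule used in `step()` above (`done = hidden_pass_rate == 1.0 or attempt_number >= MAX_STEPS_PER_EPISODE`) can be illustrated in isolation. This is a minimal sketch, not code from the repository: the helper name `attempts_used` is hypothetical, and the value `3` for `MAX_STEPS_PER_EPISODE` is an assumption inferred from the "You have up to 3 attempts" feedback string.

```python
# ASSUMPTION: the real constant is defined elsewhere in the repo; 3 is inferred
# from the reset() feedback text, not confirmed by this diff.
MAX_STEPS_PER_EPISODE = 3


def attempts_used(hidden_pass_rates: list) -> int:
    """Return how many attempts are consumed before the episode ends.

    hidden_pass_rates holds the hidden pass rate the verifier would report
    for each successive submission (hypothetical input for illustration).
    """
    for attempt_number, rate in enumerate(hidden_pass_rates, start=1):
        # Same rule as in step(): stop on a full pass or when attempts run out.
        done = rate == 1.0 or attempt_number >= MAX_STEPS_PER_EPISODE
        if done:
            return attempt_number
    # The caller stopped submitting before the episode finished.
    return len(hidden_pass_rates)


print(attempts_used([0.5, 1.0, 1.0]))       # 2: solved on the second attempt
print(attempts_used([0.0, 0.2, 0.4, 1.0]))  # 3: attempt budget exhausted
```

Under these assumptions, a fourth submission never runs: once `done` is true the environment only replays the last observation until `reset()` is called.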
257
 
258
  @property
 
279
  ) -> AdaptObservation:
280
  public_problem = self._public_problem_view()
281
  return AdaptObservation(
282
+ session_id=self.session_id,
283
  problem_id=self.problem.get("problem_id", ""),
284
  problem_type=self.problem.get("problem_type", ""),
285
  difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
286
+ attempt_number=self._state.step_count,
287
+ max_steps=MAX_STEPS_PER_EPISODE,
288
  problem=public_problem.get("problem", ""),
289
  input_format=public_problem.get("input_format", ""),
290
  constraints=public_problem.get("constraints", ""),
291
  feedback=feedback,
292
+ pass_rate=round(float(pass_rate), 4),
293
+ visible_pass_rate=round(float(visible_pass_rate), 4),
294
+ hidden_pass_rate=round(float(hidden_pass_rate), 4),
295
  syntax_valid=syntax_valid,
296
  execution_status=execution_status,
297
  timeout_count=timeout_count,
298
  runtime_error_count=runtime_error_count,
299
  invalid_output_count=invalid_output_count,
300
  wrong_answer_count=wrong_answer_count,
301
+ format_compliance=round(float(format_compliance), 4),
302
  reward_components=reward_components or {},
303
  generator_reward_signal=round(float(generator_reward_signal), 4),
304
  reward=round(max(0.0, min(1.0, reward)), 4),
 
309
  self,
310
  generated_problem: dict[str, Any] | None,
311
  problem_id: str | None,
312
+ family_weights: dict[str, float] | None,
313
  ) -> dict[str, Any]:
314
+ candidate = generated_problem or self.generator.generate_problem(
315
  self.difficulty,
316
  self.history,
317
  problem_id=problem_id,
318
+ family_weights=family_weights,
319
  )
320
  if validate_problem(candidate):
321
  return candidate
322
+ fallback = self.generator.generate_problem(
323
+ self.difficulty,
324
+ self.history,
325
+ problem_id=problem_id,
326
+ family_weights=family_weights,
327
+ )
328
  if not validate_problem(fallback):
329
  raise ValueError("Generator produced an invalid problem twice in a row.")
330
  return fallback
 
333
  try:
334
  from verifier.verifier import verify
335
  except ImportError as exc:
336
+ return 0.0, {
337
+ "feedback": f"Verifier unavailable: {exc}",
338
+ "execution_status": "verifier_error",
339
+ "results": [],
340
+ }
341
 
342
  try:
343
  reward, metadata = verify(code, self.test_cases)
344
  except Exception as exc:
345
+ return 0.0, {
346
+ "feedback": f"Verifier crashed: {exc}",
347
+ "execution_status": "verifier_error",
348
+ "results": [],
349
+ }
350
 
351
  metadata = dict(metadata or {})
352
  diversity_bonus = self._diversity_bonus(self.problem.get("problem_type", ""))
353
  validity_bonus = float(self.problem.get("validity_bonus", 0.0))
354
+ hidden_pass_rate = float(metadata.get("hidden_pass_rate", metadata.get("pass_rate", 0.0)))
355
  metadata["generator_reward"] = generator_reward(
356
+ hidden_pass_rate,
357
  diversity_bonus=diversity_bonus,
358
  validity_bonus=validity_bonus,
359
  )
360
  return float(reward), metadata
361
 
+    def _shape_reward(
+        self,
+        pass_rate: float,
+        step_number: int,
+        execution_status: str,
+        previous_pass_rate: float,
+        done: bool,
+    ) -> tuple[float, dict[str, float]]:
+        step_discount = 1.0 if step_number == 1 else (0.85 if step_number == 2 else 0.70)
+        progress_delta = max(0.0, float(pass_rate) - float(previous_pass_rate))
+
+        if execution_status in {"timeout", "syntax_error", "safety_violation"}:
+            reward = 0.0
+        elif pass_rate == 1.0:
+            reward = compute_reward(
+                pass_rate=pass_rate,
+                step_number=step_number,
+                execution_status=execution_status,
+                format_compliance=0.0,
+            )
+        elif done:
+            reward = 0.0
+        else:
+            reward = round(0.1 * progress_delta, 4)
+
+        return reward, {
+            "correctness": round(float(pass_rate), 4),
+            "step_discount": round(step_discount, 4),
+            "progress_delta": round(progress_delta, 4),
+            "reward": round(float(reward), 4),
+        }
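The shaping rule in `_shape_reward` can be read as a small pure function. The following is a minimal sketch under stated assumptions, not the repository's implementation: `compute_reward` is not shown in this diff, so the sketch assumes a full pass earns `pass_rate * step_discount`. The names `step_discount_for` and `shape_reward_sketch` are hypothetical.

```python
def step_discount_for(step_number: int) -> float:
    """Discount schedule as it appears in the diff: 1.0, 0.85, then 0.70."""
    if step_number == 1:
        return 1.0
    return 0.85 if step_number == 2 else 0.70


def shape_reward_sketch(
    pass_rate: float,
    step_number: int,
    execution_status: str,
    previous_pass_rate: float,
    done: bool,
) -> float:
    # Only improvements over the previous attempt count as progress.
    progress_delta = max(0.0, pass_rate - previous_pass_rate)

    if execution_status in {"timeout", "syntax_error", "safety_violation"}:
        return 0.0
    if pass_rate == 1.0:
        # ASSUMPTION: stands in for compute_reward(), which this diff does not show.
        return round(pass_rate * step_discount_for(step_number), 4)
    if done:
        # The final attempt without a full pass earns nothing.
        return 0.0
    # A non-terminal partial result earns a small bonus for new progress only.
    return round(0.1 * progress_delta, 4)


print(shape_reward_sketch(1.0, 2, "completed", 0.5, True))    # 0.85
print(shape_reward_sketch(0.75, 1, "completed", 0.0, False))  # 0.075
```

The design choice visible here is that partial credit is deliberately small (at most 0.1) and only paid for improvement, so repeating the same partially correct submission earns nothing.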
393
 
+    def _format_feedback(
+        self,
+        results: list[dict[str, Any]],
+        attempt_number: int,
+        previous_status: str,
+        execution_status: str,
+        hidden_pass_rate: float,
+        visible_pass_rate: float,
+    ) -> str:
+        lines = [
+            f"Attempt {attempt_number}/{MAX_STEPS_PER_EPISODE}.",
+            f"Previous attempt status: {previous_status}.",
+            f"Current execution status: {execution_status}.",
+            f"Hidden pass rate: {hidden_pass_rate:.2f}. Visible pass rate: {visible_pass_rate:.2f}.",
+        ]
+
+        failed_tests = self._summarize_failed_tests(results)
+        if failed_tests:
+            lines.append("Failed tests:")
+            lines.extend(failed_tests)
+        elif hidden_pass_rate == 1.0:
+            lines.append("All hidden tests passed.")
+        else:
+            lines.append("No failing test details were available.")
+
+        return "\n".join(lines)
+
+    def _format_static_feedback(
+        self,
+        attempt_number: int,
+        previous_status: str,
+        execution_status: str,
+        details: str,
+    ) -> str:
+        return "\n".join(
+            [
+                f"Attempt {attempt_number}/{MAX_STEPS_PER_EPISODE}.",
+                f"Previous attempt status: {previous_status}.",
+                f"Current execution status: {execution_status}.",
+                details,
+            ]
+        )
436
 
+    def _summarize_failed_tests(self, results: list[dict[str, Any]]) -> list[str]:
+        summaries: list[str] = []
+        for result in results:
+            if result.get("passed", False):
+                continue
+            visibility = str(result.get("visibility", "hidden"))
+            label = f"{visibility.title()} test #{int(result.get('index', 0)) + 1}"
+            status = str(result.get("status", "unknown"))
+            if visibility == "visible":
+                actual = str(result.get("stdout", "")).strip()
+                expected = str(result.get("expected", "")).strip()
+                details = []
+                if expected:
+                    details.append(f"expected={expected}")
+                if actual:
+                    details.append(f"got={actual}")
+                if result.get("stderr"):
+                    details.append("stderr_present")
+                if details:
+                    summaries.append(f"- {label}: {status} ({', '.join(details)})")
+                else:
+                    summaries.append(f"- {label}: {status}")
+            else:
+                summaries.append(f"- {label}: {status}")
+        return summaries
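To make the visible/hidden asymmetry in `_summarize_failed_tests` concrete, here is a standalone sketch of the same logic applied to sample verifier results. The result keys (`passed`, `visibility`, `index`, `status`, `stdout`, `expected`, `stderr`) come from the diff; the sample status strings such as `wrong_answer` are assumptions, since the verifier's actual vocabulary is not shown here.

```python
from typing import Any


def summarize_failed_tests(results: list) -> list:
    """Standalone version of the summary logic, for illustration only."""
    summaries = []
    for result in results:
        if result.get("passed", False):
            continue
        visibility = str(result.get("visibility", "hidden"))
        label = f"{visibility.title()} test #{int(result.get('index', 0)) + 1}"
        status = str(result.get("status", "unknown"))
        details = []
        if visibility == "visible":
            # Expected/actual values are revealed only for visible tests.
            expected = str(result.get("expected", "")).strip()
            actual = str(result.get("stdout", "")).strip()
            if expected:
                details.append(f"expected={expected}")
            if actual:
                details.append(f"got={actual}")
            if result.get("stderr"):
                details.append("stderr_present")
        suffix = f" ({', '.join(details)})" if details else ""
        summaries.append(f"- {label}: {status}{suffix}")
    return summaries


sample_results: list = [
    {"passed": True, "visibility": "visible", "index": 0},
    {"passed": False, "visibility": "visible", "index": 1,
     "status": "wrong_answer", "stdout": "7\n", "expected": "9"},
    {"passed": False, "visibility": "hidden", "index": 4,
     "status": "timeout", "stdout": "secret"},
]
for line in summarize_failed_tests(sample_results):
    print(line)
# - Visible test #2: wrong_answer (expected=9, got=7)
# - Hidden test #5: timeout
```

Hidden tests report only a status, so the feedback string cannot leak hidden inputs or expected outputs to the agent.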
462
 
463
  def _record_metrics(self, observation: AdaptObservation) -> None:
464
+ attempt_record = {
465
+ "attempt_number": observation.attempt_number,
466
+ "reward": float(observation.reward or 0.0),
467
+ "pass_rate": float(observation.pass_rate),
468
+ "visible_pass_rate": float(observation.visible_pass_rate),
469
+ "execution_status": observation.execution_status,
470
+ "feedback": observation.feedback,
471
+ "done": bool(observation.done),
472
+ }
473
+ self.attempt_history.append(attempt_record)
474
  self._state.last_reward = float(observation.reward or 0.0)
475
  self._state.last_pass_rate = observation.pass_rate
476
  self._state.last_feedback = observation.feedback
477
+ self._state.last_execution_status = observation.execution_status
478
  self._state.generator_reward_signal = observation.generator_reward_signal
479
+ self._state.history = {"attempts": list(self.attempt_history)}
480
+ self._state.generated_problem = self._public_problem_view()
 
 
 
481
  self._state.recent_metrics = {
482
  "difficulty_tier": self.difficulty,
483
  "difficulty_label": self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
484
+ "visible_pass_rate": observation.visible_pass_rate,
485
  "pass_rate": observation.pass_rate,
486
  "execution_status": observation.execution_status,
487
  "timeout_count": observation.timeout_count,
 
492
  "reward_components": dict(observation.reward_components),
493
  }
494
 
+    def _finalize_episode(self, observation: AdaptObservation) -> None:
+        self.episode_done = True
+        self._update_history(observation.pass_rate, observation.generator_reward_signal)
+
+    def _update_history(self, pass_rate: float, generator_signal: float) -> None:
+        self.history["recent_pass_rates"].append(round(float(pass_rate), 4))
+        self.history["problem_types"].append(self.problem.get("problem_type", ""))
+        self.history["problem_signatures"].append(self.problem.get("problem_id", ""))
+        self.history["generator_rewards"].append(round(float(generator_signal), 4))
+        self.history["episode_index"] = int(self.history.get("episode_index", 0)) + 1
+
+        for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
+            values = self.history[key]
+            if len(values) > self.max_history:
+                del values[:-self.max_history]
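The trimming idiom at the end of `_update_history` keeps each history list bounded by deleting everything except the most recent entries in place. A minimal sketch of that idiom follows; the helper name `trim_to_recent` is hypothetical and the sample values are made up for illustration.

```python
def trim_to_recent(values: list, max_history: int) -> None:
    """Keep only the last max_history items, mutating the list in place."""
    if len(values) > max_history:
        # Delete the prefix; the slice [:-k] is everything except the last k items.
        del values[:-max_history]


recent_pass_rates = [0.2, 0.4, 0.6, 0.8, 1.0]
trim_to_recent(recent_pass_rates, 3)
print(recent_pass_rates)  # [0.6, 0.8, 1.0]
```

One caveat worth knowing: the idiom assumes `max_history` is positive. With `max_history == 0` the slice becomes `values[:0]`, which is empty, so nothing would be removed.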
510
 
511
  def _public_problem_view(self) -> dict[str, str]:
512
  visible = dict(self.problem.get("visible_problem", {}))
513
+ base_problem = visible.get("problem", self.problem.get("problem", ""))
514
+ examples = self._format_examples()
515
+ if examples:
516
+ base_problem = f"{base_problem}\n\nExamples:\n{examples}"
517
  return {
518
+ "problem": base_problem,
519
  "input_format": visible.get("input_format", self.problem.get("input_format", "")),
520
  "constraints": visible.get("constraints", self.problem.get("constraints", "")),
521
  }
522
 
+    def _format_examples(self) -> str:
+        visible_cases = [test_case for test_case in self.test_cases if test_case.get("is_visible", False)]
+        if not visible_cases:
+            return ""
+        chunks = []
+        for test_case in visible_cases:
+            chunks.append(
+                f"Input:\n{test_case['input']}Expected Output:\n{test_case['output']}\n"
+            )
+        return "\n".join(chunks).rstrip()
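A worked example shows what `_format_examples` appends to the problem statement. This sketch reimplements the method as a free function (`format_examples` is a hypothetical name) and feeds it sample cases shaped like the generator's output. Note that the f-string places `Expected Output:` directly after the input, so it relies on every input ending with a newline, as `_array_case` produces.

```python
def format_examples(test_cases: list) -> str:
    """Render only the visible test cases as an 'Examples' section."""
    visible_cases = [case for case in test_cases if case.get("is_visible", False)]
    if not visible_cases:
        return ""
    chunks = [
        # The input is expected to end with "\n" already.
        f"Input:\n{case['input']}Expected Output:\n{case['output']}\n"
        for case in visible_cases
    ]
    return "\n".join(chunks).rstrip()


sample_cases = [
    {"input": "3\n1 2 4\n", "output": "6", "is_visible": True},
    {"input": "2\n5 7\n", "output": "0", "is_visible": False},  # hidden, never shown
]
print(format_examples(sample_cases))
# Input:
# 3
# 1 2 4
# Expected Output:
# 6
```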
533
+
534
  def _diversity_bonus(self, problem_type: str) -> float:
535
+ recent_types = list(self.history.get("problem_types", [])[-6:])
536
  if not recent_types:
537
  return 0.1
538
  if problem_type in recent_types:
env/generator.py CHANGED
@@ -1,13 +1,15 @@
1
  from __future__ import annotations
2
 
3
  import hashlib
4
- import math
5
  import random
 
6
  from dataclasses import dataclass
7
  from typing import Any, Callable
8
 
9
- VISIBLE_TEST_COUNT = 0
10
- MIN_TEST_CASES = 5
 
 
11
 
12
 
13
  @dataclass(frozen=True)
@@ -17,9 +19,9 @@ class ProblemTemplate:
17
  title: str
18
  input_format: str
19
  constraints: str
20
- statement_builder: Callable[[dict[str, Any]], str]
21
  solver: Callable[[str], str]
22
- case_builder: Callable[[random.Random, float], list[str]]
23
 
24
 
25
  def generator_reward(
@@ -54,14 +56,15 @@ def validate_problem(problem_dict: dict[str, Any]) -> bool:
54
  return False
55
 
56
  test_cases = problem_dict.get("test_cases")
57
- if not isinstance(test_cases, list) or len(test_cases) < MIN_TEST_CASES:
58
  return False
59
 
60
  seen_inputs: set[str] = set()
61
  distinct_outputs: set[str] = set()
62
  visible_count = 0
 
63
 
64
- for test_case in test_cases:
65
  if not isinstance(test_case, dict):
66
  return False
67
 
@@ -76,12 +79,26 @@ def validate_problem(problem_dict: dict[str, Any]) -> bool:
76
  return False
77
  seen_inputs.add(raw_input)
78
  distinct_outputs.add(raw_output.strip())
79
- visible_count += 1 if is_visible else 0
80
 
81
- if visible_count != VISIBLE_TEST_COUNT:
82
  return False
83
 
84
- if len(distinct_outputs) < max(3, len(test_cases) // 3):
85
  return False
86
 
87
  return True
@@ -94,46 +111,46 @@ class GeneratorAgent:
94
  self.deterministic = deterministic
95
  self.templates = _build_templates()
96
 
97
- def generate(
98
  self,
99
  difficulty_level: int | float | str,
100
  history: dict[str, Any] | None,
101
  problem_id: str | None = None,
 
102
  ) -> dict[str, Any]:
103
  history = history or {}
104
  target_tier = _difficulty_to_tier(difficulty_level)
105
- adjusted_tier = self._adjust_tier(target_tier, history)
106
- rng = self._rng_for(adjusted_tier, history, problem_id)
107
- template = self._choose_template(adjusted_tier, history, rng, forced_problem_type=problem_id)
108
-
109
- for attempt in range(10):
110
- params = {
111
- "window": 3 + adjusted_tier,
112
- "modulus": 10 + 5 * adjusted_tier,
113
- "max_n": 8 + adjusted_tier * 4,
114
- "attempt": attempt,
115
- }
116
- raw_cases = template.case_builder(rng, 0.2 + adjusted_tier * 0.25)
117
  test_cases = [
118
  {
119
  "input": case_input,
120
  "output": template.solver(case_input),
121
- "is_visible": False,
122
  }
123
- for case_input in raw_cases
124
  ]
125
  signature = self._problem_signature(template.problem_type, test_cases)
126
  problem = {
127
  "problem_id": f"{template.problem_type}_{signature[:8]}",
128
  "problem_type": template.problem_type,
129
- "difficulty": round(self._tier_to_scalar(adjusted_tier), 4),
130
- "difficulty_label": DIFFICULTY_LABELS[adjusted_tier],
131
- "problem": template.statement_builder(params),
132
  "input_format": template.input_format,
133
  "constraints": template.constraints,
134
  "test_cases": test_cases,
135
  "visible_problem": {
136
- "problem": template.statement_builder(params),
137
  "input_format": template.input_format,
138
  "constraints": template.constraints,
139
  },
@@ -145,17 +162,19 @@ class GeneratorAgent:
145
 
146
  raise ValueError(f"Unable to generate a valid problem for template {template.problem_type}")
147
 
148
- def _adjust_tier(self, target_tier: int, history: dict[str, Any]) -> int:
149
- recent_pass_rates = [float(value) for value in history.get("recent_pass_rates", [])[-5:]]
150
- if not recent_pass_rates:
151
- return target_tier
152
-
153
- moving_average = sum(recent_pass_rates) / len(recent_pass_rates)
154
- if moving_average > 0.8:
155
- return min(3, target_tier + 1)
156
- if moving_average < 0.2:
157
- return max(1, target_tier - 1)
158
- return target_tier
 
 
159
 
160
  def _choose_template(
161
  self,
@@ -163,6 +182,7 @@ class GeneratorAgent:
163
  history: dict[str, Any],
164
  rng: random.Random,
165
  forced_problem_type: str | None = None,
 
166
  ) -> ProblemTemplate:
167
  eligible = [template for template in self.templates if template.difficulty_tier == tier]
168
  if not eligible:
@@ -176,27 +196,37 @@ class GeneratorAgent:
176
  if template.problem_type == forced_problem_type:
177
  return template
178
 
179
- recent_types = list(history.get("problem_types", [])[-4:])
180
- weighted: list[tuple[float, ProblemTemplate]] = []
181
  for template in eligible:
182
- repetition_penalty = 0.35 if template.problem_type in recent_types else 0.0
183
- jitter = rng.random() * 0.2
184
- weighted.append((1.0 - repetition_penalty + jitter, template))
185
- weighted.sort(key=lambda item: item[0], reverse=True)
186
- return weighted[0][1]
 
 
187
 
188
  def _rng_for(
189
  self,
190
  tier: int,
191
  history: dict[str, Any],
192
  problem_id: str | None,
 
193
  ) -> random.Random:
 
 
 
194
  seed_material = {
195
  "tier": tier,
196
  "problem_id": problem_id or "",
197
  "pass_rates": [round(float(value), 4) for value in history.get("recent_pass_rates", [])[-8:]],
198
  "problem_types": list(history.get("problem_types", [])[-8:]),
199
  "episode_index": int(history.get("episode_index", 0)),
 
 
 
 
200
  }
201
  digest = hashlib.sha256(repr(seed_material).encode("utf-8")).hexdigest()
202
  return random.Random(int(digest[:16], 16))
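The seeding approach in `_rng_for` derives a reproducible random stream from a dictionary of context values by hashing its `repr`. Below is a minimal sketch of that technique using only the standard library; the helper names `seed_from_material` and `rng_from_material` are hypothetical, and the sample material is a reduced version of the keys shown in the diff.

```python
import hashlib
import random


def seed_from_material(seed_material: dict) -> int:
    """Hash the repr of the material and keep the first 64 bits as an integer seed."""
    digest = hashlib.sha256(repr(seed_material).encode("utf-8")).hexdigest()
    return int(digest[:16], 16)


def rng_from_material(seed_material: dict) -> random.Random:
    """Build an independent generator so the global random state is untouched."""
    return random.Random(seed_from_material(seed_material))


material = {"tier": 2, "problem_id": "", "episode_index": 0}
first = rng_from_material(material)
second = rng_from_material(dict(material))

# Identical material yields identical random streams.
same = [first.randint(0, 99) for _ in range(3)] == [second.randint(0, 99) for _ in range(3)]
print(same)  # True
```

Because the seed comes from `repr`, dictionary key order matters: two dictionaries with the same items inserted in a different order produce different `repr` strings and therefore different seeds.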
@@ -246,7 +276,7 @@ def _build_templates() -> list[ProblemTemplate]:
246
  title="Sum Even Numbers",
247
  input_format="The first line contains n. The second line contains n space-separated integers.",
248
  constraints="1 <= n <= 12; -100 <= values[i] <= 100",
249
- statement_builder=lambda _: (
250
  "Given a list of integers, print the sum of the numbers that are even. "
251
  "If no number is even, print 0."
252
  ),
@@ -259,19 +289,71 @@ def _build_templates() -> list[ProblemTemplate]:
259
  title="Range Span",
260
  input_format="The first line contains n. The second line contains n space-separated integers.",
261
  constraints="2 <= n <= 12; -100 <= values[i] <= 100",
262
- statement_builder=lambda _: (
263
  "Given a list of integers, print the difference between the maximum and minimum value."
264
  ),
265
  solver=_solve_range_span,
266
  case_builder=_build_range_span_cases,
267
  ),
268
  ProblemTemplate(
269
  problem_type="count_local_peaks",
270
  difficulty_tier=2,
271
  title="Count Local Peaks",
272
  input_format="The first line contains n. The second line contains n space-separated integers.",
273
- constraints="3 <= n <= 14; -100 <= values[i] <= 100",
274
- statement_builder=lambda _: (
275
  "Count how many indices i are local peaks, meaning values[i] is strictly greater than both "
276
  "values[i-1] and values[i+1]. The first and last element can never be peaks."
277
  ),
@@ -283,20 +365,81 @@ def _build_templates() -> list[ProblemTemplate]:
283
  difficulty_tier=2,
284
  title="Longest Non-Decreasing Run",
285
  input_format="The first line contains n. The second line contains n space-separated integers.",
286
- constraints="1 <= n <= 16; -100 <= values[i] <= 100",
287
- statement_builder=lambda _: (
288
  "Find the length of the longest contiguous subarray whose values are non-decreasing."
289
  ),
290
  solver=_solve_longest_non_decreasing_run,
291
  case_builder=_build_run_cases,
292
  ),
293
  ProblemTemplate(
294
  problem_type="smallest_most_frequent",
295
  difficulty_tier=3,
296
  title="Smallest Most Frequent",
297
  input_format="The first line contains n. The second line contains n space-separated integers.",
298
- constraints="1 <= n <= 18; -30 <= values[i] <= 30",
299
- statement_builder=lambda _: (
300
  "Print the value that appears most often in the array. If several values have the same highest "
301
  "frequency, print the smallest of them."
302
  ),
@@ -308,79 +451,343 @@ def _build_templates() -> list[ProblemTemplate]:
308
  difficulty_tier=3,
309
  title="Reverse Words",
310
  input_format="A single line containing one or more words separated by spaces.",
311
- constraints="1 <= line length <= 80",
312
- statement_builder=lambda _: (
313
  "Read a line of text and print the words in reverse order. Multiple spaces in the input should "
314
  "be treated as a single separator."
315
  ),
316
  solver=_solve_reverse_words,
317
  case_builder=_build_reverse_word_cases,
318
  ),
319
  ]
320
 
321
 
322
- def _build_sum_even_cases(rng: random.Random, difficulty_scalar: float) -> list[str]:
323
- size = 5 + math.ceil(difficulty_scalar * 5)
324
- cases = set()
325
- while len(cases) < 6:
326
- numbers = [rng.randint(-25, 25) for _ in range(size + rng.randint(0, 3))]
327
- if all(number % 2 for number in numbers):
328
- numbers[0] = 0
329
- cases.add(_array_case(numbers))
330
- return list(cases)
331
-
332
-
333
- def _build_range_span_cases(rng: random.Random, difficulty_scalar: float) -> list[str]:
334
- size = 4 + math.ceil(difficulty_scalar * 6)
335
- cases = set()
336
- while len(cases) < 6:
337
- numbers = [rng.randint(-40, 40) for _ in range(size + rng.randint(0, 3))]
338
- if len(set(numbers)) == 1:
339
- numbers[-1] += 3
340
- cases.add(_array_case(numbers))
341
- return list(cases)
342
-
343
-
344
- def _build_peak_cases(rng: random.Random, difficulty_scalar: float) -> list[str]:
345
- size = 5 + math.ceil(difficulty_scalar * 6)
346
- cases = set()
347
- while len(cases) < 6:
348
- numbers = []
349
- current = rng.randint(-10, 10)
350
- for index in range(size + rng.randint(0, 4)):
351
- delta = rng.randint(-6, 6)
352
- if index % 2 == 1:
353
- delta = abs(delta) + 1
354
- current += delta
355
- numbers.append(current)
356
- numbers[0] -= 5
357
- numbers[-1] -= 5
358
- cases.add(_array_case(numbers))
359
- return list(cases)
360
-
361
-
362
- def _build_run_cases(rng: random.Random, difficulty_scalar: float) -> list[str]:
363
- size = 6 + math.ceil(difficulty_scalar * 6)
364
- cases = set()
365
- while len(cases) < 6:
366
- numbers = [rng.randint(-20, 20)]
367
- for _ in range(size + rng.randint(0, 4) - 1):
368
- numbers.append(numbers[-1] + rng.randint(-5, 5))
369
- cases.add(_array_case(numbers))
370
- return list(cases)
371
-
372
-
373
- def _build_frequency_cases(rng: random.Random, difficulty_scalar: float) -> list[str]:
374
- size = 8 + math.ceil(difficulty_scalar * 6)
375
- cases = set()
376
- while len(cases) < 6:
377
- numbers = [rng.randint(-6, 6) for _ in range(size + rng.randint(0, 5))]
378
- numbers.extend([rng.choice(numbers), rng.choice(numbers)])
379
- cases.add(_array_case(numbers))
380
- return list(cases)
381
-
382
-
383
- def _build_reverse_word_cases(rng: random.Random, difficulty_scalar: float) -> list[str]:
384
  vocabulary = [
385
  "graph",
386
  "queue",
@@ -395,21 +802,170 @@ def _build_reverse_word_cases(rng: random.Random, difficulty_scalar: float) -> l
395
  "node",
396
  "edge",
397
  ]
398
- word_count = 4 + math.ceil(difficulty_scalar * 4)
399
- cases = set()
400
- while len(cases) < 6:
401
- words = [rng.choice(vocabulary) for _ in range(word_count + rng.randint(0, 2))]
402
- spacer = " " * rng.randint(1, 3)
403
- prefix = " " * rng.randint(0, 2)
404
- suffix = " " * rng.randint(0, 2)
405
- cases.add(f"{prefix}{spacer.join(words)}{suffix}\n")
406
- return list(cases)
407
 
408
 
409
  def _array_case(numbers: list[int]) -> str:
410
  return f"{len(numbers)}\n{' '.join(str(number) for number in numbers)}\n"
411
 
412
 
413
  def _solve_sum_even_numbers(stdin: str) -> str:
414
  _, numbers = _parse_int_array(stdin)
415
  return str(sum(number for number in numbers if number % 2 == 0))
@@ -420,6 +976,47 @@ def _solve_range_span(stdin: str) -> str:
420
  return str(max(numbers) - min(numbers))
421
 
422
 
423
  def _solve_count_local_peaks(stdin: str) -> str:
424
  _, numbers = _parse_int_array(stdin)
425
  peaks = 0
@@ -442,11 +1039,60 @@ def _solve_longest_non_decreasing_run(stdin: str) -> str:
442
  return str(best)
443
 
444
 
445
  def _solve_smallest_most_frequent(stdin: str) -> str:
446
  _, numbers = _parse_int_array(stdin)
447
- counts: dict[int, int] = {}
448
- for number in numbers:
449
- counts[number] = counts.get(number, 0) + 1
450
  best_count = max(counts.values())
451
  best_value = min(number for number, count in counts.items() if count == best_count)
452
  return str(best_value)
@@ -457,6 +1103,76 @@ def _solve_reverse_words(stdin: str) -> str:
457
  return " ".join(reversed(words))
458
 
459
 
460
  def _parse_int_array(stdin: str) -> tuple[int, list[int]]:
461
  lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
462
  n = int(lines[0])
@@ -464,3 +1180,146 @@ def _parse_int_array(stdin: str) -> tuple[int, list[int]]:
464
  if len(numbers) != n:
465
  raise ValueError(f"Expected {n} integers, received {len(numbers)}")
466
  return n, numbers
1
  from __future__ import annotations
2
 
3
  import hashlib
 
4
  import random
5
+ from collections import Counter, deque
6
  from dataclasses import dataclass
7
  from typing import Any, Callable
8
 
9
+ VISIBLE_TEST_COUNT = 2
10
+ HIDDEN_TEST_COUNT = 8
11
+ TOTAL_TEST_CASES = VISIBLE_TEST_COUNT + HIDDEN_TEST_COUNT
12
+ MIN_TEST_CASES = TOTAL_TEST_CASES
13
 
14
 
15
  @dataclass(frozen=True)
 
19
  title: str
20
  input_format: str
21
  constraints: str
22
+ statement_builder: Callable[[], str]
23
  solver: Callable[[str], str]
24
+ case_builder: Callable[[random.Random], list[str]]
25
 
26
 
27
  def generator_reward(
 
56
  return False
57
 
58
  test_cases = problem_dict.get("test_cases")
59
+ if not isinstance(test_cases, list) or len(test_cases) != TOTAL_TEST_CASES:
60
  return False
61
 
62
  seen_inputs: set[str] = set()
63
  distinct_outputs: set[str] = set()
64
  visible_count = 0
65
+ hidden_count = 0
66
 
67
+ for index, test_case in enumerate(test_cases):
68
  if not isinstance(test_case, dict):
69
  return False
70
 
 
79
  return False
80
  seen_inputs.add(raw_input)
81
  distinct_outputs.add(raw_output.strip())
 
82
 
83
+ if index < VISIBLE_TEST_COUNT and not is_visible:
84
+ return False
85
+ if index >= VISIBLE_TEST_COUNT and is_visible:
86
+ return False
87
+
88
+ if is_visible:
89
+ visible_count += 1
90
+ else:
91
+ hidden_count += 1
92
+
93
+ if visible_count != VISIBLE_TEST_COUNT or hidden_count != HIDDEN_TEST_COUNT:
94
  return False
95
 
96
+ normalized_outputs = {output.strip().lower() for output in distinct_outputs}
97
+ min_output_diversity = 2 if normalized_outputs.issubset({"yes", "no", "true", "false", "0", "1"}) else max(
98
+ 3,
99
+ len(test_cases) // 3,
100
+ )
101
+ if len(distinct_outputs) < min_output_diversity:
102
  return False
103
 
104
  return True
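The output-diversity check at the end of the new `validate_problem` relaxes the threshold for yes/no style problems. The sketch below isolates that rule; the function name `min_output_diversity` and the constant name `BINARY_LABELS` are hypothetical, while the label set and the `max(3, n // 3)` fallback are taken from the diff.

```python
BINARY_LABELS = {"yes", "no", "true", "false", "0", "1"}


def min_output_diversity(distinct_outputs: set, total_cases: int) -> int:
    """Return how many distinct outputs a problem must have to be accepted."""
    normalized = {output.strip().lower() for output in distinct_outputs}
    if normalized.issubset(BINARY_LABELS):
        # Decision problems can only ever produce two answers.
        return 2
    return max(3, total_cases // 3)


print(min_output_diversity({"YES", "no"}, 10))      # 2
print(min_output_diversity({"3", "7", "12"}, 10))   # 3
print(min_output_diversity({"3", "7", "12"}, 12))   # 4
```

One consequence of this rule, as written, is that a numeric problem whose sampled answers all happen to be `0` or `1` is also treated as binary and only needs two distinct outputs.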
 
111
  self.deterministic = deterministic
112
  self.templates = _build_templates()
113
 
114
+ def generate_problem(
115
  self,
116
  difficulty_level: int | float | str,
117
  history: dict[str, Any] | None,
118
  problem_id: str | None = None,
119
+ family_weights: dict[str, float] | None = None,
120
  ) -> dict[str, Any]:
121
  history = history or {}
122
  target_tier = _difficulty_to_tier(difficulty_level)
123
+ rng = self._rng_for(target_tier, history, problem_id, family_weights or {})
124
+ template = self._choose_template(
125
+ target_tier,
126
+ history,
127
+ rng,
128
+ forced_problem_type=problem_id,
129
+ family_weights=family_weights or {},
130
+ )
131
+
132
+ for _ in range(20):
133
+ raw_cases = template.case_builder(rng)
 
134
  test_cases = [
135
  {
136
  "input": case_input,
137
  "output": template.solver(case_input),
138
+ "is_visible": index < VISIBLE_TEST_COUNT,
139
  }
140
+ for index, case_input in enumerate(raw_cases)
141
  ]
142
  signature = self._problem_signature(template.problem_type, test_cases)
143
  problem = {
144
  "problem_id": f"{template.problem_type}_{signature[:8]}",
145
  "problem_type": template.problem_type,
146
+ "difficulty": round(self._tier_to_scalar(target_tier), 4),
147
+ "difficulty_label": DIFFICULTY_LABELS[target_tier],
148
+ "problem": template.statement_builder(),
149
  "input_format": template.input_format,
150
  "constraints": template.constraints,
151
  "test_cases": test_cases,
152
  "visible_problem": {
153
+ "problem": template.statement_builder(),
154
  "input_format": template.input_format,
155
  "constraints": template.constraints,
156
  },
 
162
 
163
  raise ValueError(f"Unable to generate a valid problem for template {template.problem_type}")
164
 
165
+ def generate(
166
+ self,
167
+ difficulty_level: int | float | str,
168
+ history: dict[str, Any] | None,
169
+ problem_id: str | None = None,
170
+ family_weights: dict[str, float] | None = None,
171
+ ) -> dict[str, Any]:
172
+ return self.generate_problem(
173
+ difficulty_level=difficulty_level,
174
+ history=history,
175
+ problem_id=problem_id,
176
+ family_weights=family_weights,
177
+ )
178
 
179
  def _choose_template(
180
  self,
 
182
  history: dict[str, Any],
183
  rng: random.Random,
184
  forced_problem_type: str | None = None,
185
+ family_weights: dict[str, float] | None = None,
186
  ) -> ProblemTemplate:
187
  eligible = [template for template in self.templates if template.difficulty_tier == tier]
188
  if not eligible:
 
196
  if template.problem_type == forced_problem_type:
197
  return template
198
 
199
+ recent_types = list(history.get("problem_types", [])[-6:])
200
+ weights: list[float] = []
201
  for template in eligible:
202
+ base_weight = float((family_weights or {}).get(template.problem_type, 1.0))
203
+ base_weight = max(base_weight, 1e-6)
204
+ if template.problem_type in recent_types:
205
+ base_weight *= 0.35
206
+ weights.append(base_weight)
207
+
208
+ return rng.choices(eligible, weights=weights, k=1)[0]
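A minimal stand-alone sketch of the weighting step above, assuming a flat list of family names rather than `ProblemTemplate` objects. `pick_family` is a hypothetical name used only for illustration.

```python
import random


def pick_family(families, recent, family_weights, rng):
    """Hypothetical stand-alone version of the recency-penalised weighted choice."""
    weights = []
    for family in families:
        # Missing families default to weight 1.0; the floor keeps every weight positive.
        weight = max(float(family_weights.get(family, 1.0)), 1e-6)
        if family in recent[-6:]:
            weight *= 0.35  # recently used families are down-weighted, not excluded
        weights.append(weight)
    return rng.choices(families, weights=weights, k=1)[0]


rng = random.Random(0)
draws = [pick_family(["a", "b"], ["a"], {}, rng) for _ in range(2000)]
share_a = draws.count("a") / len(draws)  # expected near 0.35 / 1.35, about 0.26
```

Because the penalty multiplies the weight instead of removing the family, a recently seen problem type can still be drawn, just less often.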
209
 
210
  def _rng_for(
211
  self,
212
  tier: int,
213
  history: dict[str, Any],
214
  problem_id: str | None,
215
+ family_weights: dict[str, float],
216
  ) -> random.Random:
217
+ if not self.deterministic:
218
+ return random.Random()
219
+
220
  seed_material = {
221
  "tier": tier,
222
  "problem_id": problem_id or "",
223
  "pass_rates": [round(float(value), 4) for value in history.get("recent_pass_rates", [])[-8:]],
224
  "problem_types": list(history.get("problem_types", [])[-8:]),
225
  "episode_index": int(history.get("episode_index", 0)),
226
+ "family_weights": {
227
+ key: round(float(value), 4)
228
+ for key, value in sorted(family_weights.items())
229
+ },
230
  }
231
  digest = hashlib.sha256(repr(seed_material).encode("utf-8")).hexdigest()
232
  return random.Random(int(digest[:16], 16))
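The deterministic branch derives a seed by hashing the `repr` of the seed material. A minimal sketch of that idea follows; `seeded_rng` is a hypothetical helper name and the sample dictionary is made up for illustration.

```python
import hashlib
import random


def seeded_rng(seed_material: dict) -> random.Random:
    """Hypothetical helper: hash the repr of the seed material into a 64-bit seed."""
    digest = hashlib.sha256(repr(seed_material).encode("utf-8")).hexdigest()
    # The first 16 hex characters give a 64-bit integer seed.
    return random.Random(int(digest[:16], 16))


material = {"tier": 2, "problem_id": "", "episode_index": 7}
first = seeded_rng(material).random()
second = seeded_rng(dict(material)).random()
changed = seeded_rng({**material, "episode_index": 8}).random()
```

Since `repr` of a dictionary follows insertion order, the seed is only stable if the keys are always inserted in the same order; sorting nested mappings, as done for `family_weights`, serves that purpose.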
 
276
  title="Sum Even Numbers",
277
  input_format="The first line contains n. The second line contains n space-separated integers.",
278
  constraints="1 <= n <= 12; -100 <= values[i] <= 100",
279
+ statement_builder=lambda: (
280
  "Given a list of integers, print the sum of the numbers that are even. "
281
  "If no number is even, print 0."
282
  ),
 
289
  title="Range Span",
290
  input_format="The first line contains n. The second line contains n space-separated integers.",
291
  constraints="2 <= n <= 12; -100 <= values[i] <= 100",
292
+ statement_builder=lambda: (
293
  "Given a list of integers, print the difference between the maximum and minimum value."
294
  ),
295
  solver=_solve_range_span,
296
  case_builder=_build_range_span_cases,
297
  ),
298
+ ProblemTemplate(
299
+ problem_type="count_vowels",
300
+ difficulty_tier=1,
301
+ title="Count Vowels",
302
+ input_format="A single line containing lowercase or uppercase letters and spaces.",
303
+ constraints="1 <= line length <= 80",
304
+ statement_builder=lambda: (
305
+ "Count how many vowels appear in the input line. Treat a, e, i, o, u as vowels "
306
+ "and ignore case."
307
+ ),
308
+ solver=_solve_count_vowels,
309
+ case_builder=_build_count_vowels_cases,
310
+ ),
311
+ ProblemTemplate(
312
+ problem_type="max_consecutive_ones",
313
+ difficulty_tier=1,
314
+ title="Max Consecutive Ones",
315
+ input_format="A single line containing a binary string.",
316
+ constraints="1 <= string length <= 40",
317
+ statement_builder=lambda: (
318
+ "Print the length of the longest contiguous block of '1' characters in the binary string."
319
+ ),
320
+ solver=_solve_max_consecutive_ones,
321
+ case_builder=_build_max_consecutive_ones_cases,
322
+ ),
323
+ ProblemTemplate(
324
+ problem_type="fizzbuzz_variant",
325
+ difficulty_tier=1,
326
+ title="FizzBuzz Variant",
327
+ input_format="The first line contains n a b. The second line contains label_a and label_b.",
328
+ constraints="1 <= n <= 25; 2 <= a, b <= 9; labels contain only letters",
329
+ statement_builder=lambda: (
330
+ "For each integer from 1 to n, print label_a if the number is divisible by a only, "
331
+ "label_b if it is divisible by b only, and the concatenation label_a+label_b if it is divisible "
332
+ "by both. Otherwise print the number itself. Output all tokens on one line separated by spaces."
333
+ ),
334
+ solver=_solve_fizzbuzz_variant,
335
+ case_builder=_build_fizzbuzz_variant_cases,
336
+ ),
337
+ ProblemTemplate(
338
+ problem_type="running_total",
339
+ difficulty_tier=1,
340
+ title="Running Total",
341
+ input_format="The first line contains n. The second line contains n space-separated integers.",
342
+ constraints="1 <= n <= 14; -50 <= values[i] <= 50",
343
+ statement_builder=lambda: (
344
+ "Print the running total after each element of the array. Output the cumulative sums on one line "
345
+ "separated by spaces."
346
+ ),
347
+ solver=_solve_running_total,
348
+ case_builder=_build_running_total_cases,
349
+ ),
350
  ProblemTemplate(
351
  problem_type="count_local_peaks",
352
  difficulty_tier=2,
353
  title="Count Local Peaks",
354
  input_format="The first line contains n. The second line contains n space-separated integers.",
355
+ constraints="3 <= n <= 16; -100 <= values[i] <= 100",
356
+ statement_builder=lambda: (
357
  "Count how many indices i are local peaks, meaning values[i] is strictly greater than both "
358
  "values[i-1] and values[i+1]. The first and last element can never be peaks."
359
  ),
 
365
  difficulty_tier=2,
366
  title="Longest Non-Decreasing Run",
367
  input_format="The first line contains n. The second line contains n space-separated integers.",
368
+ constraints="1 <= n <= 18; -100 <= values[i] <= 100",
369
+ statement_builder=lambda: (
370
  "Find the length of the longest contiguous subarray whose values are non-decreasing."
371
  ),
372
  solver=_solve_longest_non_decreasing_run,
373
  case_builder=_build_run_cases,
374
  ),
375
+ ProblemTemplate(
376
+ problem_type="two_sum_count",
377
+ difficulty_tier=2,
378
+ title="Two Sum Count",
379
+ input_format="The first line contains n and target. The second line contains n space-separated integers.",
380
+ constraints="2 <= n <= 16; -50 <= values[i] <= 50",
381
+ statement_builder=lambda: (
382
+ "Count how many index pairs (i, j) with i < j have values[i] + values[j] equal to target."
383
+ ),
384
+ solver=_solve_two_sum_count,
385
+ case_builder=_build_two_sum_count_cases,
386
+ ),
387
+ ProblemTemplate(
388
+ problem_type="max_subarray_sum",
389
+ difficulty_tier=2,
390
+ title="Maximum Subarray Sum",
391
+ input_format="The first line contains n. The second line contains n space-separated integers.",
392
+ constraints="1 <= n <= 18; -50 <= values[i] <= 50",
393
+ statement_builder=lambda: (
394
+ "Print the maximum possible sum of a contiguous subarray."
395
+ ),
396
+ solver=_solve_max_subarray_sum,
397
+ case_builder=_build_max_subarray_sum_cases,
398
+ ),
399
+ ProblemTemplate(
400
+ problem_type="group_anagrams_count",
401
+ difficulty_tier=2,
402
+ title="Group Anagrams Count",
403
+ input_format="The first line contains n. The second line contains n space-separated lowercase words.",
404
+ constraints="1 <= n <= 12; each word length is between 1 and 8",
405
+ statement_builder=lambda: (
406
+ "Group words that are anagrams of each other. Print the number of distinct anagram groups."
407
+ ),
408
+ solver=_solve_group_anagrams_count,
409
+ case_builder=_build_group_anagrams_cases,
410
+ ),
411
+ ProblemTemplate(
412
+ problem_type="balanced_brackets",
413
+ difficulty_tier=2,
414
+ title="Balanced Brackets",
415
+ input_format="A single line containing only the characters ()[]{}.",
416
+ constraints="1 <= line length <= 50",
417
+ statement_builder=lambda: (
418
+ "Print YES if the bracket string is balanced and NO otherwise."
419
+ ),
420
+ solver=_solve_balanced_brackets,
421
+ case_builder=_build_balanced_brackets_cases,
422
+ ),
423
+ ProblemTemplate(
424
+ problem_type="matrix_diagonal_sum",
425
+ difficulty_tier=2,
426
+ title="Matrix Diagonal Sum",
427
+ input_format="The first line contains n. The next n lines each contain n space-separated integers.",
428
+ constraints="2 <= n <= 6; -20 <= matrix[i][j] <= 20",
429
+ statement_builder=lambda: (
430
+ "For the square matrix, print the sum of the primary diagonal and secondary diagonal. "
431
+ "If n is odd, count the center element only once."
432
+ ),
433
+ solver=_solve_matrix_diagonal_sum,
434
+ case_builder=_build_matrix_diagonal_sum_cases,
435
+ ),
436
  ProblemTemplate(
437
  problem_type="smallest_most_frequent",
438
  difficulty_tier=3,
439
  title="Smallest Most Frequent",
440
  input_format="The first line contains n. The second line contains n space-separated integers.",
441
+ constraints="1 <= n <= 20; -30 <= values[i] <= 30",
442
+ statement_builder=lambda: (
443
  "Print the value that appears most often in the array. If several values have the same highest "
444
  "frequency, print the smallest of them."
445
  ),
 
451
  difficulty_tier=3,
452
  title="Reverse Words",
453
  input_format="A single line containing one or more words separated by spaces.",
454
+ constraints="1 <= line length <= 120",
455
+ statement_builder=lambda: (
456
  "Read a line of text and print the words in reverse order. Multiple spaces in the input should "
457
  "be treated as a single separator."
458
  ),
459
  solver=_solve_reverse_words,
460
  case_builder=_build_reverse_word_cases,
461
  ),
462
+ ProblemTemplate(
463
+ problem_type="longest_common_subsequence",
464
+ difficulty_tier=3,
465
+ title="Longest Common Subsequence",
466
+ input_format="The first line contains string s. The second line contains string t.",
467
+ constraints="1 <= len(s), len(t) <= 18; strings contain lowercase letters",
468
+ statement_builder=lambda: (
469
+ "Print the length of the longest common subsequence of the two strings."
470
+ ),
471
+ solver=_solve_longest_common_subsequence,
472
+ case_builder=_build_lcs_cases,
473
+ ),
474
+ ProblemTemplate(
475
+ problem_type="word_ladder_steps",
476
+ difficulty_tier=3,
477
+ title="Word Ladder Steps",
478
+ input_format="The first line contains start and target. The second line contains n. The third line contains n space-separated words.",
479
+ constraints="All words have the same length between 3 and 5; 1 <= n <= 14",
480
+ statement_builder=lambda: (
481
+ "You may change one character at a time. Every intermediate word and the target word must appear "
482
+ "in the given word list. Print the minimum number of single-character changes needed to transform "
483
+ "start into target, or -1 if it is impossible."
484
+ ),
485
+ solver=_solve_word_ladder_steps,
486
+ case_builder=_build_word_ladder_cases,
487
+ ),
488
+ ProblemTemplate(
489
+ problem_type="merge_intervals",
490
+ difficulty_tier=3,
491
+ title="Merge Intervals",
492
+ input_format="The first line contains n. The next n lines each contain start and end.",
493
+ constraints="1 <= n <= 12; -20 <= start <= end <= 30",
494
+ statement_builder=lambda: (
495
+ "Merge all overlapping intervals and print how many intervals remain after merging."
496
+ ),
497
+ solver=_solve_merge_intervals,
498
+ case_builder=_build_merge_intervals_cases,
499
+ ),
500
+ ProblemTemplate(
501
+ problem_type="min_coins",
502
+ difficulty_tier=3,
503
+ title="Minimum Coins",
504
+ input_format="The first line contains n and target. The second line contains n distinct positive coin values.",
505
+ constraints="1 <= n <= 8; 1 <= target <= 40; 1 <= coin values <= 20",
506
+ statement_builder=lambda: (
507
+ "Print the minimum number of coins needed to make exactly target using unlimited copies of the given "
508
+ "coin values. Print -1 if it is impossible."
509
+ ),
510
+ solver=_solve_min_coins,
511
+ case_builder=_build_min_coins_cases,
512
+ ),
513
+ ProblemTemplate(
514
+ problem_type="rotate_matrix_90",
515
+ difficulty_tier=3,
516
+ title="Rotate Matrix 90 Degrees",
517
+ input_format="The first line contains n. The next n lines each contain n space-separated integers.",
518
+ constraints="2 <= n <= 5; -20 <= matrix[i][j] <= 20",
519
+ statement_builder=lambda: (
520
+ "Rotate the square matrix 90 degrees clockwise and print the rotated matrix flattened in row-major "
521
+ "order on one line separated by spaces."
522
+ ),
523
+ solver=_solve_rotate_matrix_90,
524
+ case_builder=_build_rotate_matrix_cases,
525
+ ),
526
+ ]
527
+
528
+
529
+ def _build_sum_even_cases(rng: random.Random) -> list[str]:
530
+ visible_pool = [
531
+ _array_case([2, 3, 4]),
532
+ _array_case([1, 3, 5, 7]),
533
+ _array_case([0, -2, 5, 8]),
534
+ ]
535
+ return _cases_from_pool_and_factory(rng, visible_pool, _sum_even_hidden_case)
536
+
537
+
538
+ def _sum_even_hidden_case(rng: random.Random) -> str:
539
+ length = rng.randint(5, 12)
540
+ numbers = [rng.randint(-50, 50) for _ in range(length)]
541
+ if all(number % 2 for number in numbers):
542
+ numbers[rng.randrange(length)] = rng.choice([-8, -2, 0, 6, 14])
543
+ return _array_case(numbers)
544
+
545
+
546
+ def _build_range_span_cases(rng: random.Random) -> list[str]:
547
+ visible_pool = [
548
+ _array_case([1, 4, 9]),
549
+ _array_case([-2, -2, -2, 1]),
550
+ _array_case([8, 3]),
551
+ ]
552
+ return _cases_from_pool_and_factory(rng, visible_pool, _range_span_hidden_case)
553
+
554
+
555
+ def _range_span_hidden_case(rng: random.Random) -> str:
556
+ length = rng.randint(4, 12)
557
+ numbers = [rng.randint(-60, 60) for _ in range(length)]
558
+ if len(set(numbers)) == 1:
559
+ numbers[-1] += 5
560
+ return _array_case(numbers)
561
+
562
+
563
+ def _build_count_vowels_cases(rng: random.Random) -> list[str]:
564
+ visible_pool = [
565
+ "hello world\n",
566
+ "sky\n",
567
+ "AEIOU\n",
568
+ ]
569
+ return _cases_from_pool_and_factory(rng, visible_pool, _count_vowels_hidden_case)
570
+
571
+
572
+ def _count_vowels_hidden_case(rng: random.Random) -> str:
573
+ word_bank = [
574
+ "algorithm",
575
+ "queue",
576
+ "stack",
577
+ "binary",
578
+ "graph",
579
+ "open env",
580
+ "unit test",
581
+ "dynamic programming",
582
+ "vowel heavy area",
583
+ "crypt rhythm",
584
+ ]
585
+ parts = [rng.choice(word_bank) for _ in range(rng.randint(1, 3))]
586
+ text = " ".join(parts)
587
+ return f"{text[:80]}\n"
588
+
589
+
590
+ def _build_max_consecutive_ones_cases(rng: random.Random) -> list[str]:
591
+ visible_pool = ["1101110\n", "00000\n", "1\n"]
592
+ return _cases_from_pool_and_factory(rng, visible_pool, _max_consecutive_ones_hidden_case)
593
+
594
+
595
+ def _max_consecutive_ones_hidden_case(rng: random.Random) -> str:
596
+ length = rng.randint(8, 40)
597
+ chars = [rng.choice(["0", "1"]) for _ in range(length)]
598
+ if "1" not in chars:
599
+ start = rng.randint(0, max(0, length - 3))
600
+ run_length = rng.randint(1, min(5, length - start))
601
+ for index in range(start, start + run_length):
602
+ chars[index] = "1"
603
+ return f"{''.join(chars)}\n"
604
+
605
+
606
+ def _build_fizzbuzz_variant_cases(rng: random.Random) -> list[str]:
607
+ visible_pool = [
608
+ _fizzbuzz_case(8, 3, 5, "Fizz", "Buzz"),
609
+ _fizzbuzz_case(6, 2, 4, "Hop", "Pop"),
610
+ _fizzbuzz_case(10, 2, 3, "Up", "Go"),
611
+ ]
612
+ return _cases_from_pool_and_factory(rng, visible_pool, _fizzbuzz_hidden_case)
613
+
614
+
615
+ def _fizzbuzz_hidden_case(rng: random.Random) -> str:
616
+ labels = [
617
+ ("Fizz", "Buzz"),
618
+ ("Ping", "Pong"),
619
+ ("Hop", "Skip"),
620
+ ("Alpha", "Beta"),
621
+ ("Red", "Blue"),
622
+ ]
623
+ label_a, label_b = rng.choice(labels)
624
+ a = rng.randint(2, 6)
625
+ b = rng.randint(2, 6)
626
+ while b == a:
627
+ b = rng.randint(2, 6)
628
+ n = rng.randint(10, 25)
629
+ return _fizzbuzz_case(n, a, b, label_a, label_b)
630
+
631
+
632
+ def _build_running_total_cases(rng: random.Random) -> list[str]:
633
+ visible_pool = [
634
+ _array_case([1, 2, 3, 4]),
635
+ _array_case([5, -2, 7]),
636
+ _array_case([0, 0, 1]),
637
+ ]
638
+ return _cases_from_pool_and_factory(rng, visible_pool, _running_total_hidden_case)
639
+
640
+
641
+ def _running_total_hidden_case(rng: random.Random) -> str:
642
+ length = rng.randint(5, 14)
643
+ numbers = [rng.randint(-20, 20) for _ in range(length)]
644
+ return _array_case(numbers)
645
+
646
+
647
+ def _build_peak_cases(rng: random.Random) -> list[str]:
648
+ visible_pool = [
649
+ _array_case([1, 3, 2, 4, 1]),
650
+ _array_case([5, 4, 3, 2, 1]),
651
+ _array_case([2, 5, 1, 5, 2]),
652
+ ]
653
+ return _cases_from_pool_and_factory(rng, visible_pool, _peak_hidden_case)
654
+
655
+
656
+ def _peak_hidden_case(rng: random.Random) -> str:
657
+ length = rng.randint(6, 16)
658
+ numbers = [rng.randint(-20, 20)]
659
+ for _ in range(length - 1):
660
+ numbers.append(max(-100, min(100, numbers[-1] + rng.randint(-8, 8))))  # clamp so the walk stays within -100 <= values[i] <= 100
661
+ return _array_case(numbers)
662
+
663
+
664
+ def _build_run_cases(rng: random.Random) -> list[str]:
665
+ visible_pool = [
666
+ _array_case([1, 2, 2, 1, 3]),
667
+ _array_case([5, 4, 3, 2]),
668
+ _array_case([1, 1, 1, 1]),
669
+ ]
670
+ return _cases_from_pool_and_factory(rng, visible_pool, _run_hidden_case)
671
+
672
+
673
+ def _run_hidden_case(rng: random.Random) -> str:
674
+ length = rng.randint(6, 18)
675
+ numbers = [rng.randint(-20, 20)]
676
+ for _ in range(length - 1):
677
+ numbers.append(max(-100, min(100, numbers[-1] + rng.randint(-6, 6))))  # clamp so the walk stays within -100 <= values[i] <= 100
678
+ return _array_case(numbers)
679
+
680
+
681
+ def _build_two_sum_count_cases(rng: random.Random) -> list[str]:
682
+ visible_pool = [
683
+ _target_array_case(5, [1, 2, 3, 4]),
684
+ _target_array_case(2, [1, 1, 1, 1]),
685
+ _target_array_case(0, [-1, 1, 2, -2]),
686
+ ]
687
+ return _cases_from_pool_and_factory(rng, visible_pool, _two_sum_hidden_case)
688
+
689
+
690
+ def _two_sum_hidden_case(rng: random.Random) -> str:
691
+ length = rng.randint(5, 16)
692
+ numbers = [rng.randint(-12, 12) for _ in range(length)]
693
+ target = rng.randint(-10, 10)
694
+ return _target_array_case(target, numbers)
695
+
696
+
697
+ def _build_max_subarray_sum_cases(rng: random.Random) -> list[str]:
698
+ visible_pool = [
699
+ _array_case([1, -2, 3, 4, -1]),
700
+ _array_case([-5, -1, -8]),
701
+ _array_case([2, -1, 2, 3, 4, -5]),
702
  ]
703
+ return _cases_from_pool_and_factory(rng, visible_pool, _max_subarray_hidden_case)
704
+
705
+
706
+ def _max_subarray_hidden_case(rng: random.Random) -> str:
707
+ length = rng.randint(6, 18)
708
+ numbers = [rng.randint(-20, 20) for _ in range(length)]
709
+ return _array_case(numbers)
710
 
711
 
712
+ def _build_group_anagrams_cases(rng: random.Random) -> list[str]:
713
+ visible_pool = [
714
+ _word_list_case(["eat", "tea", "tan", "ate", "nat", "bat"]),
715
+ _word_list_case(["abc", "bca", "cab", "foo"]),
716
+ _word_list_case(["a", "b", "ab", "ba"]),
717
+ ]
718
+ return _cases_from_pool_and_factory(rng, visible_pool, _group_anagrams_hidden_case)
719
+
720
+
721
+ def _group_anagrams_hidden_case(rng: random.Random) -> str:
722
+ base_words = ["stone", "tones", "notes", "silent", "listen", "enlist", "rat", "tar", "art"]
723
+ words: list[str] = []
724
+ for _ in range(rng.randint(4, 10)):
725
+ word = rng.choice(base_words)
726
+ if rng.random() < 0.4:
727
+ shuffled = list(word)
728
+ rng.shuffle(shuffled)
729
+ words.append("".join(shuffled))
730
+ else:
731
+ words.append(word)
732
+ return _word_list_case(words)
733
+
734
+
735
+ def _build_balanced_brackets_cases(rng: random.Random) -> list[str]:
736
+ visible_pool = [
737
+ "([]{})\n",
738
+ "([)]\n",
739
+ "{[()]}\n",
740
+ ]
741
+ return _cases_from_pool_and_factory(rng, visible_pool, _balanced_brackets_hidden_case)
742
+
743
+
744
+ def _balanced_brackets_hidden_case(rng: random.Random) -> str:
745
+ if rng.random() < 0.5:
746
+ return f"{_make_balanced_brackets(rng, rng.randint(3, 10))}\n"
747
+ return f"{_make_unbalanced_brackets(rng, rng.randint(3, 10))}\n"
748
+
749
+
750
+ def _build_matrix_diagonal_sum_cases(rng: random.Random) -> list[str]:
751
+ visible_pool = [
752
+ _matrix_case([[1, 2], [3, 4]]),
753
+ _matrix_case([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
754
+ _matrix_case([[2, 0, 2], [1, 5, 1], [2, 0, 2]]),
755
+ ]
756
+ return _cases_from_pool_and_factory(rng, visible_pool, _matrix_diagonal_hidden_case)
757
+
758
+
759
+ def _matrix_diagonal_hidden_case(rng: random.Random) -> str:
760
+ size = rng.randint(3, 6)
761
+ matrix = [[rng.randint(-9, 9) for _ in range(size)] for _ in range(size)]
762
+ return _matrix_case(matrix)
763
+
764
+
765
+ def _build_frequency_cases(rng: random.Random) -> list[str]:
766
+ visible_pool = [
767
+ _array_case([1, 2, 2, 3, 3, 3]),
768
+ _array_case([4, 4, 1, 1]),
769
+ _array_case([-1, -1, -2, -2, -2, 3]),
770
+ ]
771
+ return _cases_from_pool_and_factory(rng, visible_pool, _frequency_hidden_case)
772
+
773
+
774
+ def _frequency_hidden_case(rng: random.Random) -> str:
775
+ length = rng.randint(6, 18)  # two duplicates are appended below, so n stays within 1 <= n <= 20
776
+ numbers = [rng.randint(-8, 8) for _ in range(length)]
777
+ numbers.extend([rng.choice(numbers), rng.choice(numbers)])
778
+ return _array_case(numbers)
779
+
780
+
781
+ def _build_reverse_word_cases(rng: random.Random) -> list[str]:
782
+ visible_pool = [
783
+ "hello world here\n",
784
+ " graph search tree \n",
785
+ "one\n",
786
+ ]
787
+ return _cases_from_pool_and_factory(rng, visible_pool, _reverse_words_hidden_case)
788
+
789
+
790
+ def _reverse_words_hidden_case(rng: random.Random) -> str:
791
  vocabulary = [
792
  "graph",
793
  "queue",
 
802
  "node",
803
  "edge",
804
  ]
805
+ words = [rng.choice(vocabulary) for _ in range(rng.randint(4, 9))]
806
+ spacer = " " * rng.randint(1, 3)
807
+ prefix = " " * rng.randint(0, 2)
808
+ suffix = " " * rng.randint(0, 2)
809
+ return f"{prefix}{spacer.join(words)}{suffix}\n"
810
+
811
+
812
+ def _build_lcs_cases(rng: random.Random) -> list[str]:
813
+ visible_pool = [
814
+ _two_line_case("abcde", "ace"),
815
+ _two_line_case("abc", "abc"),
816
+ _two_line_case("abc", "def"),
817
+ ]
818
+ return _cases_from_pool_and_factory(rng, visible_pool, _lcs_hidden_case)
819
+
820
+
821
+ def _lcs_hidden_case(rng: random.Random) -> str:
822
+ alphabet = "abcdxyz"
823
+ left = "".join(rng.choice(alphabet) for _ in range(rng.randint(6, 14)))
824
+ right = "".join(rng.choice(alphabet) for _ in range(rng.randint(6, 14)))
825
+ return _two_line_case(left, right)
826
+
827
+
828
+ def _build_word_ladder_cases(rng: random.Random) -> list[str]:
829
+ visible_pool = [
830
+ _word_ladder_case("hit", "cog", ["hot", "dot", "dog", "lot", "log", "cog"]),
831
+ _word_ladder_case("same", "same", ["same", "lame", "came"]),
832
+ _word_ladder_case("cold", "warm", ["cord", "card", "ward", "sold"]),
833
+ ]
834
+ return _cases_from_pool_and_factory(rng, visible_pool, _word_ladder_hidden_case)
835
+
836
+
837
+ def _word_ladder_hidden_case(rng: random.Random) -> str:
838
+ length = rng.randint(3, 5)
839
+ if rng.random() < 0.7:
840
+ path_length = rng.randint(2, 5)
841
+ path = _build_word_ladder_path(rng, length, path_length)
842
+ extras = _build_word_ladder_extras(rng, length, rng.randint(2, 7), set(path))
843
+ words = path[1:] + extras
844
+ rng.shuffle(words)
845
+ return _word_ladder_case(path[0], path[-1], words)
846
+
847
+ start = _random_word(rng, length)
848
+ target = _random_word(rng, length)
849
+ while target == start:
850
+ target = _random_word(rng, length)
851
+ extras = _build_word_ladder_extras(rng, length, rng.randint(4, 10), {start, target})
852
+ return _word_ladder_case(start, target, extras)
853
+
854
+
855
+ def _build_merge_intervals_cases(rng: random.Random) -> list[str]:
856
+ visible_pool = [
857
+ _interval_case([(1, 3), (2, 4), (6, 8)]),
858
+ _interval_case([(1, 2), (3, 4), (5, 6)]),
859
+ _interval_case([(0, 5), (2, 3), (4, 10)]),
860
+ ]
861
+ return _cases_from_pool_and_factory(rng, visible_pool, _merge_intervals_hidden_case)
862
+
863
+
864
+ def _merge_intervals_hidden_case(rng: random.Random) -> str:
865
+ intervals = []
866
+ for _ in range(rng.randint(4, 12)):
867
+ start = rng.randint(-10, 20)
868
+ end = start + rng.randint(0, 8)
869
+ intervals.append((start, end))
870
+ return _interval_case(intervals)
871
+
872
+
873
+ def _build_min_coins_cases(rng: random.Random) -> list[str]:
874
+ visible_pool = [
875
+ _coin_case([1, 3, 4], 6),
876
+ _coin_case([2, 5], 3),
877
+ _coin_case([2, 5, 7], 14),
878
+ ]
879
+ return _cases_from_pool_and_factory(rng, visible_pool, _min_coins_hidden_case)
880
+
881
+
882
+ def _min_coins_hidden_case(rng: random.Random) -> str:
883
+ coin_count = rng.randint(2, 6)
884
+ coins = sorted({rng.randint(1, 10) for _ in range(coin_count + 2)})
885
+ coins = coins[:coin_count]
886
+ target = rng.randint(5, 40)
887
+ return _coin_case(coins, target)
888
+
889
+
890
+ def _build_rotate_matrix_cases(rng: random.Random) -> list[str]:
891
+ visible_pool = [
892
+ _matrix_case([[1, 2], [3, 4]]),
893
+ _matrix_case([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
894
+ _matrix_case([[5, 1], [0, -1]]),
895
+ ]
896
+ return _cases_from_pool_and_factory(rng, visible_pool, _rotate_matrix_hidden_case)
897
+
898
+
899
+ def _rotate_matrix_hidden_case(rng: random.Random) -> str:
900
+ size = rng.randint(2, 5)
901
+ matrix = [[rng.randint(-9, 9) for _ in range(size)] for _ in range(size)]
902
+ return _matrix_case(matrix)
903
+
904
+
905
+ def _cases_from_pool_and_factory(
906
+ rng: random.Random,
907
+ visible_pool: list[str],
908
+ hidden_factory: Callable[[random.Random], str],
909
+ ) -> list[str]:
910
+ cases: list[str] = []
911
+ seen: set[str] = set()
912
+
913
+ for case_input in rng.sample(visible_pool, k=VISIBLE_TEST_COUNT):
914
+ cases.append(case_input)
915
+ seen.add(case_input)
916
+
917
+ attempts = 0
918
+ while len(cases) < TOTAL_TEST_CASES:
919
+ candidate = hidden_factory(rng)
920
+ attempts += 1
921
+ if candidate in seen:
922
+ if attempts > 200:
923
+ raise ValueError("Unable to generate unique test cases.")
924
+ continue
925
+ seen.add(candidate)
926
+ cases.append(candidate)
927
+
928
+ return cases
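The pool-and-factory pattern above can be sketched as a stand-alone function. This is a variation under stated assumptions, not the repository's code: `build_cases` and `max_attempts` are hypothetical names, and the attempt limit here bounds every draw rather than only the duplicate ones.

```python
import random

VISIBLE_TEST_COUNT = 2
TOTAL_TEST_CASES = 10


def build_cases(rng, visible_pool, hidden_factory, max_attempts=200):
    """Hypothetical sketch: sample visible cases, then fill with unique generated ones."""
    cases = rng.sample(visible_pool, k=VISIBLE_TEST_COUNT)
    seen = set(cases)
    attempts = 0
    while len(cases) < TOTAL_TEST_CASES:
        attempts += 1
        if attempts > max_attempts:
            raise ValueError("Unable to generate unique test cases.")
        candidate = hidden_factory(rng)
        if candidate not in seen:  # duplicates are skipped and retried
            seen.add(candidate)
            cases.append(candidate)
    return cases


pool = ["a\n", "b\n", "c\n"]
cases = build_cases(random.Random(1), pool, lambda r: f"{r.randint(0, 10**6)}\n")
```

Keeping the visible cases first in the returned list is what lets the caller mark visibility purely by index.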
929
 
930
 
931
  def _array_case(numbers: list[int]) -> str:
932
  return f"{len(numbers)}\n{' '.join(str(number) for number in numbers)}\n"
933
 
934
 
935
+ def _target_array_case(target: int, numbers: list[int]) -> str:
936
+ return f"{len(numbers)} {target}\n{' '.join(str(number) for number in numbers)}\n"
937
+
938
+
939
+ def _word_list_case(words: list[str]) -> str:
940
+ return f"{len(words)}\n{' '.join(words)}\n"
941
+
942
+
943
+ def _matrix_case(matrix: list[list[int]]) -> str:
944
+ rows = [" ".join(str(value) for value in row) for row in matrix]
945
+ return f"{len(matrix)}\n" + "\n".join(rows) + "\n"
946
+
947
+
948
+ def _two_line_case(first: str, second: str) -> str:
949
+ return f"{first}\n{second}\n"
950
+
951
+
952
+ def _interval_case(intervals: list[tuple[int, int]]) -> str:
953
+ rows = [f"{start} {end}" for start, end in intervals]
954
+ return f"{len(intervals)}\n" + "\n".join(rows) + "\n"
955
+
956
+
957
+ def _coin_case(coins: list[int], target: int) -> str:
958
+ return f"{len(coins)} {target}\n{' '.join(str(coin) for coin in coins)}\n"
959
+
960
+
961
+ def _fizzbuzz_case(n: int, a: int, b: int, label_a: str, label_b: str) -> str:
962
+ return f"{n} {a} {b}\n{label_a} {label_b}\n"
963
+
964
+
965
+ def _word_ladder_case(start: str, target: str, words: list[str]) -> str:
966
+ return f"{start} {target}\n{len(words)}\n{' '.join(words)}\n"
967
+
968
+
969
  def _solve_sum_even_numbers(stdin: str) -> str:
970
  _, numbers = _parse_int_array(stdin)
971
  return str(sum(number for number in numbers if number % 2 == 0))
 
976
  return str(max(numbers) - min(numbers))
977
 
978
 
979
+ def _solve_count_vowels(stdin: str) -> str:
980
+ text = stdin.rstrip("\n")
981
+ return str(sum(1 for char in text.lower() if char in "aeiou"))
982
+
983
+
984
+ def _solve_max_consecutive_ones(stdin: str) -> str:
985
+ binary = stdin.strip()
986
+ best = 0
987
+ current = 0
988
+ for char in binary:
989
+ if char == "1":
990
+ current += 1
991
+ best = max(best, current)
992
+ else:
993
+ current = 0
994
+ return str(best)
995
+
996
+
997
+ def _solve_fizzbuzz_variant(stdin: str) -> str:
998
+ (n, a, b), (label_a, label_b) = _parse_fizzbuzz(stdin)
999
+ output = []
1000
+ for value in range(1, n + 1):
1001
+ token = ""
1002
+ if value % a == 0:
1003
+ token += label_a
1004
+ if value % b == 0:
1005
+ token += label_b
1006
+ output.append(token or str(value))
1007
+ return " ".join(output)
1008
+
1009
+
1010
+ def _solve_running_total(stdin: str) -> str:
1011
+ _, numbers = _parse_int_array(stdin)
1012
+ total = 0
1013
+ running = []
1014
+ for number in numbers:
1015
+ total += number
1016
+ running.append(str(total))
1017
+ return " ".join(running)
1018
+
1019
+
1020
  def _solve_count_local_peaks(stdin: str) -> str:
1021
  _, numbers = _parse_int_array(stdin)
1022
  peaks = 0
 
1039
  return str(best)
1040
 
1041
 
1042
+ def _solve_two_sum_count(stdin: str) -> str:
1043
+ _, target, numbers = _parse_target_array(stdin)
1044
+ counts: Counter[int] = Counter()
1045
+ pairs = 0
1046
+ for number in numbers:
1047
+ pairs += counts[target - number]
1048
+ counts[number] += 1
1049
+ return str(pairs)
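The solver above counts pairs in one pass by looking up how many earlier values equal `target - number`. A small sketch on plain lists, with `count_pairs` as a hypothetical name and inputs taken from the visible pool of this problem type:

```python
from collections import Counter


def count_pairs(numbers, target):
    """Hypothetical list-based version of the one-pass pair count."""
    seen = Counter()
    pairs = 0
    for number in numbers:
        # Every earlier occurrence of the complement forms a pair (i, j) with i < j.
        pairs += seen[target - number]
        seen[number] += 1
    return pairs


print(count_pairs([1, 2, 3, 4], 5))   # 2: (1, 4) and (2, 3)
print(count_pairs([1, 1, 1, 1], 2))   # 6: every pair of the four ones
```

Counting before inserting the current value is what prevents an element from pairing with itself.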
1050
+
1051
+
1052
+ def _solve_max_subarray_sum(stdin: str) -> str:
1053
+ _, numbers = _parse_int_array(stdin)
1054
+ best = numbers[0]
1055
+ current = numbers[0]
1056
+ for number in numbers[1:]:
1057
+ current = max(number, current + number)
1058
+ best = max(best, current)
1059
+ return str(best)
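The loop above is Kadane's algorithm. A list-based sketch with a hypothetical `max_subarray_sum` helper, using the visible cases of this problem type as inputs:

```python
def max_subarray_sum(numbers):
    """Hypothetical list-based version of the Kadane loop (non-empty input assumed)."""
    best = current = numbers[0]
    for number in numbers[1:]:
        # Either extend the running subarray or start a new one at this element.
        current = max(number, current + number)
        best = max(best, current)
    return best


print(max_subarray_sum([1, -2, 3, 4, -1]))  # 7, from the subarray [3, 4]
print(max_subarray_sum([-5, -1, -8]))       # -1, the least negative single element
```

Initialising from the first element instead of zero is what makes the all-negative case return the largest element rather than 0.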
1060
+
1061
+
1062
+ def _solve_group_anagrams_count(stdin: str) -> str:
1063
+ _, words = _parse_word_list(stdin)
1064
+ groups = {"".join(sorted(word)) for word in words}
1065
+ return str(len(groups))
1066
+
1067
+
1068
+ def _solve_balanced_brackets(stdin: str) -> str:
1069
+ text = stdin.strip()
1070
+ pairs = {")": "(", "]": "[", "}": "{"}
1071
+ stack: list[str] = []
1072
+ for char in text:
1073
+ if char in "([{":
1074
+ stack.append(char)
1075
+ elif char in pairs:
1076
+ if not stack or stack.pop() != pairs[char]:
1077
+ return "NO"
1078
+ return "YES" if not stack else "NO"
1079
+
1080
+
1081
+ def _solve_matrix_diagonal_sum(stdin: str) -> str:
1082
+ _, matrix = _parse_matrix(stdin)
1083
+ total = 0
1084
+ size = len(matrix)
1085
+ for index in range(size):
1086
+ total += matrix[index][index]
1087
+ mirrored = size - 1 - index
1088
+ if mirrored != index:
1089
+ total += matrix[index][mirrored]
1090
+ return str(total)
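A quick worked check of the center-element rule, assuming a hypothetical `diagonal_sum` helper that takes the matrix directly instead of parsing stdin:

```python
def diagonal_sum(matrix):
    """Hypothetical list-based version of the diagonal sum."""
    size = len(matrix)
    total = 0
    for index in range(size):
        total += matrix[index][index]            # primary diagonal
        mirrored = size - 1 - index
        if mirrored != index:                    # skip the shared center cell for odd sizes
            total += matrix[index][mirrored]     # secondary diagonal
    return total


print(diagonal_sum([[1, 2], [3, 4]]))                   # 10 = (1 + 4) + (2 + 3)
print(diagonal_sum([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))  # 25 = 15 + 15 - 5
```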
1091
+
1092
+
1093
  def _solve_smallest_most_frequent(stdin: str) -> str:
1094
  _, numbers = _parse_int_array(stdin)
1095
+ counts = Counter(numbers)
 
 
1096
  best_count = max(counts.values())
1097
  best_value = min(number for number, count in counts.items() if count == best_count)
1098
  return str(best_value)
 
1103
  return " ".join(reversed(words))
1104
 
1105
 
+ def _solve_longest_common_subsequence(stdin: str) -> str:
+     left, right = _parse_two_strings(stdin)
+     dp = [[0] * (len(right) + 1) for _ in range(len(left) + 1)]
+     for i in range(1, len(left) + 1):
+         for j in range(1, len(right) + 1):
+             if left[i - 1] == right[j - 1]:
+                 dp[i][j] = dp[i - 1][j - 1] + 1
+             else:
+                 dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
+     return str(dp[-1][-1])
+
+
+ def _solve_word_ladder_steps(stdin: str) -> str:
+     start, target, words = _parse_word_ladder(stdin)
+     if start == target:
+         return "0"
+     word_set = set(words)
+     if target not in word_set:
+         return "-1"
+
+     queue: deque[tuple[str, int]] = deque([(start, 0)])
+     visited = {start}
+     alphabet = "abcdefghijklmnopqrstuvwxyz"
+
+     while queue:
+         current, steps = queue.popleft()
+         for index in range(len(current)):
+             for letter in alphabet:
+                 if letter == current[index]:
+                     continue
+                 candidate = current[:index] + letter + current[index + 1 :]
+                 if candidate == target:
+                     return str(steps + 1)
+                 if candidate in word_set and candidate not in visited:
+                     visited.add(candidate)
+                     queue.append((candidate, steps + 1))
+     return "-1"
+
+
+ def _solve_merge_intervals(stdin: str) -> str:
+     intervals = _parse_intervals(stdin)
+     ordered = sorted(intervals)
+     merged: list[list[int]] = []
+     for start, end in ordered:
+         if not merged or start > merged[-1][1]:
+             merged.append([start, end])
+         else:
+             merged[-1][1] = max(merged[-1][1], end)
+     return str(len(merged))
+
+
+ def _solve_min_coins(stdin: str) -> str:
+     _, target, coins = _parse_coin_problem(stdin)
+     best = [target + 1] * (target + 1)
+     best[0] = 0
+     for value in range(1, target + 1):
+         for coin in coins:
+             if coin <= value:
+                 best[value] = min(best[value], best[value - coin] + 1)
+     return str(best[target] if best[target] <= target else -1)
+
+
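Reviewer note: `_solve_min_coins` uses `target + 1` as an "unreachable" sentinel. A minimal standalone sketch of the same bottom-up recurrence follows. It assumes every coin is a positive integer (nothing in this hunk shows the generator's actual constraints), and `min_coins` is a hypothetical name introduced only for illustration.

```python
def min_coins(coins, target):
    """Fewest coins summing to `target`, or -1 when no combination exists.

    Assumes positive integer coin values. Under that assumption no solution
    can use more than `target` coins, so `target + 1` safely marks
    "unreachable" without needing float('inf').
    """
    unreachable = target + 1
    best = [unreachable] * (target + 1)
    best[0] = 0  # zero coins are needed to form the amount 0
    for value in range(1, target + 1):
        for coin in coins:
            if coin <= value:
                best[value] = min(best[value], best[value - coin] + 1)
    return best[target] if best[target] <= target else -1
```

Because an unreachable entry plus one is still larger than `target`, the final comparison is enough to tell the sentinel apart from a real answer.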
+ def _solve_rotate_matrix_90(stdin: str) -> str:
+     _, matrix = _parse_matrix(stdin)
+     size = len(matrix)
+     rotated = [[matrix[size - 1 - row][col] for row in range(size)] for col in range(size)]
+     flattened = [str(value) for row in rotated for value in row]
+     return " ".join(flattened)
+
+
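Reviewer note: the index mapping in `_solve_rotate_matrix_90` corresponds to a clockwise quarter turn. The sketch below reproduces the comprehension on plain lists and compares it with the common `zip` idiom; `rotate_clockwise` and `rotate_clockwise_zip` are hypothetical names, not part of the commit.

```python
def rotate_clockwise(matrix):
    """Rotate a square matrix 90 degrees clockwise using the same index mapping."""
    size = len(matrix)
    # Output row `col` reads input column `col` from the bottom row upwards.
    return [[matrix[size - 1 - row][col] for row in range(size)] for col in range(size)]


def rotate_clockwise_zip(matrix):
    """Equivalent formulation: reverse the row order, then transpose."""
    return [list(column) for column in zip(*matrix[::-1])]
```

Applying either function four times should return the original matrix, which makes a cheap property check for generated test cases.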
1176
  def _parse_int_array(stdin: str) -> tuple[int, list[int]]:
1177
  lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
1178
  n = int(lines[0])
 
1180
  if len(numbers) != n:
1181
  raise ValueError(f"Expected {n} integers, received {len(numbers)}")
1182
  return n, numbers
+
+
+ def _parse_target_array(stdin: str) -> tuple[int, int, list[int]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     n, target = map(int, lines[0].split())
+     numbers = [int(part) for part in lines[1].split()]
+     if len(numbers) != n:
+         raise ValueError(f"Expected {n} integers, received {len(numbers)}")
+     return n, target, numbers
+
+
+ def _parse_word_list(stdin: str) -> tuple[int, list[str]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     n = int(lines[0])
+     words = lines[1].split()
+     if len(words) != n:
+         raise ValueError(f"Expected {n} words, received {len(words)}")
+     return n, words
+
+
+ def _parse_matrix(stdin: str) -> tuple[int, list[list[int]]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     n = int(lines[0])
+     matrix = [[int(part) for part in line.split()] for line in lines[1 : n + 1]]
+     if len(matrix) != n or any(len(row) != n for row in matrix):
+         raise ValueError("Matrix dimensions do not match n.")
+     return n, matrix
+
+
+ def _parse_two_strings(stdin: str) -> tuple[str, str]:
+     lines = stdin.strip().splitlines()
+     if len(lines) < 2:
+         raise ValueError("Expected two lines of text.")
+     return lines[0].strip(), lines[1].strip()
+
+
+ def _parse_fizzbuzz(stdin: str) -> tuple[tuple[int, int, int], tuple[str, str]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     n, a, b = map(int, lines[0].split())
+     label_a, label_b = lines[1].split()
+     return (n, a, b), (label_a, label_b)
+
+
+ def _parse_word_ladder(stdin: str) -> tuple[str, str, list[str]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     start, target = lines[0].split()
+     n = int(lines[1])
+     words = lines[2].split()
+     if len(words) != n:
+         raise ValueError(f"Expected {n} words, received {len(words)}")
+     return start, target, words
+
+
+ def _parse_intervals(stdin: str) -> list[tuple[int, int]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     n = int(lines[0])
+     intervals = [tuple(map(int, line.split())) for line in lines[1 : n + 1]]
+     if len(intervals) != n:
+         raise ValueError("Interval count does not match n.")
+     return [(start, end) for start, end in intervals]
+
+
+ def _parse_coin_problem(stdin: str) -> tuple[int, int, list[int]]:
+     lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
+     n, target = map(int, lines[0].split())
+     coins = [int(part) for part in lines[1].split()]
+     if len(coins) != n:
+         raise ValueError(f"Expected {n} coins, received {len(coins)}")
+     return n, target, coins
+
+
+ def _make_balanced_brackets(rng: random.Random, pairs: int) -> str:
+     opens = ["(", "[", "{"]
+     closing = {"(": ")", "[": "]", "{": "}"}
+     stack: list[str] = []
+     output: list[str] = []
+     for _ in range(pairs * 2):
+         can_open = len(stack) < pairs and (not stack or rng.random() < 0.6)
+         if can_open:
+             token = rng.choice(opens)
+             stack.append(token)
+             output.append(token)
+         else:
+             output.append(closing[stack.pop()])
+     while stack:
+         output.append(closing[stack.pop()])
+     return "".join(output)
+
+
+ def _make_unbalanced_brackets(rng: random.Random, pairs: int) -> str:
+     text = list(_make_balanced_brackets(rng, pairs))
+     if not text:
+         return "("
+     mode = rng.choice(["swap", "drop", "flip"])
+     if mode == "swap" and len(text) >= 2:
+         index = rng.randrange(len(text) - 1)
+         text[index], text[index + 1] = text[index + 1], text[index]
+     elif mode == "drop":
+         del text[rng.randrange(len(text))]
+     else:
+         replacements = ["(", ")", "[", "]", "{", "}"]
+         text[rng.randrange(len(text))] = rng.choice(replacements)
+     return "".join(text)
+
+
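Reviewer note: the "swap" and "flip" mutations in `_make_unbalanced_brackets` can leave the string balanced. For example, swapping the middle two characters of `()()` gives `(())`, and a flip may pick the character that was already there. Whether the caller re-checks the result with `_solve_balanced_brackets` is not visible in this hunk. The sketch below shows one possible guard under that uncertainty; `is_balanced`, `swap_adjacent`, and `ensure_unbalanced` are hypothetical helpers, not part of the commit.

```python
PAIRS = {")": "(", "]": "[", "}": "{"}


def is_balanced(text):
    """Stack-based check mirroring the solver's logic, returning a bool."""
    stack = []
    for char in text:
        if char in "([{":
            stack.append(char)
        elif char in PAIRS:
            if not stack or stack.pop() != PAIRS[char]:
                return False
    return not stack


def swap_adjacent(text, index):
    """Swap the characters at `index` and `index + 1`."""
    chars = list(text)
    chars[index], chars[index + 1] = chars[index + 1], chars[index]
    return "".join(chars)


def ensure_unbalanced(text):
    """Return `text` unchanged if already unbalanced, else force an imbalance.

    Dropping one character from a balanced bracket-only string leaves an odd
    number of brackets, which can never be balanced.
    """
    if not is_balanced(text):
        return text
    return text[1:] if text else "("
```

Falling back to a deletion keeps the guard deterministic, at the cost of skewing the distribution of negative examples toward odd-length strings.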
+ def _random_word(rng: random.Random, length: int) -> str:
+     alphabet = "abcdefghijklmnopqrstuvwxyz"
+     return "".join(rng.choice(alphabet) for _ in range(length))
+
+
+ def _build_word_ladder_path(rng: random.Random, length: int, steps: int) -> list[str]:
+     alphabet = "abcdefghijklmnopqrstuvwxyz"
+     current = _random_word(rng, length)
+     path = [current]
+     used = {current}
+     while len(path) < steps + 1:
+         chars = list(path[-1])
+         index = rng.randrange(length)
+         replacement = rng.choice(alphabet.replace(chars[index], ""))
+         chars[index] = replacement
+         candidate = "".join(chars)
+         if candidate in used:
+             continue
+         used.add(candidate)
+         path.append(candidate)
+     return path
+
+
+ def _build_word_ladder_extras(
+     rng: random.Random,
+     length: int,
+     count: int,
+     disallowed: set[str],
+ ) -> list[str]:
+     words: list[str] = []
+     seen = set(disallowed)
+     while len(words) < count:
+         candidate = _random_word(rng, length)
+         if candidate in seen:
+             continue
+         seen.add(candidate)
+         words.append(candidate)
+     return words
env/test_cases.py CHANGED
@@ -2,9 +2,7 @@ from __future__ import annotations
2
 
3
  from typing import Any
4
 
5
- from env.generator import DIFFICULTY_LABELS, GeneratorAgent
6
-
7
- VISIBLE_TEST_COUNT = 0
8
 
9
 
10
  def load_problem_bank() -> list[dict[str, Any]]:
 
2
 
3
  from typing import Any
4
 
5
+ from env.generator import DIFFICULTY_LABELS, GeneratorAgent, VISIBLE_TEST_COUNT
 
 
6
 
7
 
8
  def load_problem_bank() -> list[dict[str, Any]]:
models.py CHANGED
@@ -18,17 +18,22 @@ except ImportError:
18
  episode_id: str = ""
19
  step_count: int = 0
20
 
21
- from pydantic import Field
22
-
23
 
24
  class AdaptAction(Action):
 
 
 
 
25
  code: str = Field(..., min_length=1, description="Python code to execute.")
26
 
27
 
28
  class AdaptObservation(Observation):
 
29
  problem_id: str = Field(default="", description="Current problem identifier.")
30
  problem_type: str = Field(default="", description="Current generated problem family.")
31
  difficulty: str = Field(default="", description="Current curriculum difficulty tier.")
 
 
32
  problem: str = Field(default="", description="Problem statement shown to the agent.")
33
  input_format: str = Field(default="", description="Expected stdin format.")
34
  constraints: str = Field(default="", description="Problem constraints.")
@@ -48,14 +53,17 @@ class AdaptObservation(Observation):
48
 
49
 
50
  class AdaptState(State):
 
51
  problem_id: str = Field(default="")
52
  problem_type: str = Field(default="")
53
  difficulty: str = Field(default="")
54
  generator_mode: str = Field(default="heuristic")
55
- generated_problem: dict[str, str] = Field(default_factory=dict)
 
56
  last_reward: float = Field(default=0.0)
57
  last_pass_rate: float = Field(default=0.0, ge=0.0, le=1.0)
58
  last_feedback: str = Field(default="")
 
59
  generator_reward_signal: float = Field(default=0.0)
60
  history: dict[str, Any] = Field(default_factory=dict)
61
  recent_metrics: dict[str, Any] = Field(default_factory=dict)
 
18
  episode_id: str = ""
19
  step_count: int = 0
20
 
 
 
21
 
22
  class AdaptAction(Action):
23
+ session_id: str = Field(
24
+ default="",
25
+ description="Environment session id for server-routed calls.",
26
+ )
27
  code: str = Field(..., min_length=1, description="Python code to execute.")
28
 
29
 
30
  class AdaptObservation(Observation):
31
+ session_id: str = Field(default="", description="Session id for the active environment instance.")
32
  problem_id: str = Field(default="", description="Current problem identifier.")
33
  problem_type: str = Field(default="", description="Current generated problem family.")
34
  difficulty: str = Field(default="", description="Current curriculum difficulty tier.")
35
+ attempt_number: int = Field(default=0, ge=0, description="1-indexed attempt number within the episode.")
36
+ max_steps: int = Field(default=3, ge=1, description="Maximum attempts allowed for the episode.")
37
  problem: str = Field(default="", description="Problem statement shown to the agent.")
38
  input_format: str = Field(default="", description="Expected stdin format.")
39
  constraints: str = Field(default="", description="Problem constraints.")
 
53
 
54
 
55
  class AdaptState(State):
56
+ session_id: str = Field(default="")
57
  problem_id: str = Field(default="")
58
  problem_type: str = Field(default="")
59
  difficulty: str = Field(default="")
60
  generator_mode: str = Field(default="heuristic")
61
+ max_steps: int = Field(default=3, ge=1)
62
+ generated_problem: dict[str, Any] = Field(default_factory=dict)
63
  last_reward: float = Field(default=0.0)
64
  last_pass_rate: float = Field(default=0.0, ge=0.0, le=1.0)
65
  last_feedback: str = Field(default="")
66
+ last_execution_status: str = Field(default="ready")
67
  generator_reward_signal: float = Field(default=0.0)
68
  history: dict[str, Any] = Field(default_factory=dict)
69
  recent_metrics: dict[str, Any] = Field(default_factory=dict)
openenv.yaml CHANGED
@@ -1,35 +1,7 @@
1
  spec_version: 1
2
- name: adapt_dsa_tutor
3
- type: space
4
  runtime: fastapi
5
  app: server.app:app
6
  port: 7860
7
- description: "ADAPT: an adversarial DSA tutor environment for RLVR code generation with hidden tests, tiered problems, and anti-hacking reward signals."
8
- version: "0.2.0"
9
-
10
- observation_space:
11
- type: dict
12
- description: "Problem prompt, examples, visible tests, difficulty metadata, reward, pass rates, execution status, and feedback."
13
-
14
- action_space:
15
- type: dict
16
- description: "AdaptAction with a Python code string submitted for stdin/stdout evaluation."
17
-
18
- reward_range: [0.0, 1.0]
19
-
20
- tasks:
21
- - name: easy_double
22
- description: "Easy arithmetic stdin/stdout problem."
23
- difficulty: easy
24
- - name: easy_sum_two
25
- description: "Easy two-integer arithmetic problem."
26
- difficulty: easy
27
- - name: medium_maximum
28
- description: "Medium array scanning problem."
29
- difficulty: medium
30
- - name: medium_count_even
31
- description: "Medium counting problem over a list."
32
- difficulty: medium
33
- - name: hard_reverse_words
34
- description: "Harder string normalization and ordering problem."
35
- difficulty: hard
 
  spec_version: 1
+ name: adapt-dsa-tutor
+ version: "0.3.0"
  runtime: fastapi
  app: server.app:app
  port: 7860
+ description: "Adversarial DSA Programming Tutor - RL environment for training LLMs to solve algorithmic problems through adaptive curriculum and self-repair"
scripts/test_env.py CHANGED
@@ -7,7 +7,7 @@ ROOT = Path(__file__).resolve().parents[1]
7
  if str(ROOT) not in sys.path:
8
  sys.path.insert(0, str(ROOT))
9
 
10
- from env.adapt_env import AdaptEnvironment
11
  from env.generator import GeneratorAgent
12
  from models import AdaptAction
13
 
@@ -23,10 +23,12 @@ def main() -> None:
23
  env = AdaptEnvironment(generator=GeneratorAgent())
24
  observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
25
  assert observation.problem
 
26
  assert observation.input_format
27
  assert observation.constraints
28
  assert observation.problem_type == "sum_even_numbers"
29
  assert observation.execution_status == "ready"
 
30
  assert_hidden_tests_are_not_exposed(observation.model_dump())
31
 
32
  correct = env.step(
@@ -39,11 +41,13 @@ def main() -> None:
39
  )
40
  )
41
  print(correct)
42
- assert correct.reward > 0.8, correct.model_dump()
43
  assert correct.pass_rate == 1.0
44
  assert correct.execution_status == "completed"
 
45
 
46
- wrong = env.step(
 
47
  AdaptAction(
48
  code=(
49
  "n=int(input())\n"
@@ -52,41 +56,64 @@ def main() -> None:
52
  )
53
  )
54
  )
55
- print(wrong)
56
- assert 0.0 <= float(wrong.reward) < 1.0
57
- assert wrong.execution_status in {"wrong_answer", "completed"}
58
- assert wrong.pass_rate < 1.0
59
 
60
- invalid_output = env.step(
61
  AdaptAction(
62
  code=(
63
  "n=int(input())\n"
64
- "input()\n"
65
- "print()"
 
 
 
 
 
66
  )
67
  )
68
  )
69
- print(invalid_output)
70
- assert invalid_output.invalid_output_count > 0
71
- assert invalid_output.execution_status == "invalid_output_format"
 
 
72
 
 
73
  syntax = env.step(AdaptAction(code="def broken(:\n pass"))
74
  print(syntax)
75
  assert syntax.reward == 0.0
 
76
  assert syntax.execution_status == "syntax_error"
77
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  timeout = env.step(AdaptAction(code="while True:\n pass"))
79
  print(timeout)
80
  assert timeout.timeout_count > 0
81
  assert timeout.execution_status == "timeout"
 
82
 
 
83
  unsafe = env.step(AdaptAction(code="import os\nprint(os.listdir('.'))"))
84
  print(unsafe)
85
  assert unsafe.reward == 0.0
86
  assert unsafe.execution_status == "safety_violation"
 
87
 
88
- assert env.state.step_count == 6
89
- assert env.state.history["recent_pass_rates"]
90
  assert_hidden_tests_are_not_exposed(timeout.model_dump())
91
  print("ADAPT OpenEnv smoke tests passed")
92
 
 
7
  if str(ROOT) not in sys.path:
8
  sys.path.insert(0, str(ROOT))
9
 
10
+ from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
11
  from env.generator import GeneratorAgent
12
  from models import AdaptAction
13
 
 
23
  env = AdaptEnvironment(generator=GeneratorAgent())
24
  observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
25
  assert observation.problem
26
+ assert "Examples:" in observation.problem
27
  assert observation.input_format
28
  assert observation.constraints
29
  assert observation.problem_type == "sum_even_numbers"
30
  assert observation.execution_status == "ready"
31
+ assert observation.max_steps == MAX_STEPS_PER_EPISODE
32
  assert_hidden_tests_are_not_exposed(observation.model_dump())
33
 
34
  correct = env.step(
 
41
  )
42
  )
43
  print(correct)
44
+ assert correct.reward == 1.0, correct.model_dump()
45
  assert correct.pass_rate == 1.0
46
  assert correct.execution_status == "completed"
47
+ assert correct.done is True
48
 
49
+ observation = env.reset(problem_id="running_total", difficulty="easy")
50
+ repair_1 = env.step(
51
  AdaptAction(
52
  code=(
53
  "n=int(input())\n"
 
56
  )
57
  )
58
  )
59
+ print(repair_1)
60
+ assert repair_1.done is False
61
+ assert repair_1.execution_status in {"wrong_answer", "runtime_error", "invalid_output_format"}
62
+ assert "Previous attempt status: ready" in repair_1.feedback
63
 
64
+ repair_2 = env.step(
65
  AdaptAction(
66
  code=(
67
  "n=int(input())\n"
68
+ "nums=list(map(int,input().split()))\n"
69
+ "running=0\n"
70
+ "out=[]\n"
71
+ "for x in nums:\n"
72
+ " running += x\n"
73
+ " out.append(str(running))\n"
74
+ "print(' '.join(out))"
75
  )
76
  )
77
  )
78
+ print(repair_2)
79
+ assert repair_2.done is True
80
+ assert repair_2.pass_rate == 1.0
81
+ assert repair_2.reward == 0.85
82
+ assert "Previous attempt status:" in repair_2.feedback
83
 
84
+ observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
85
  syntax = env.step(AdaptAction(code="def broken(:\n pass"))
86
  print(syntax)
87
  assert syntax.reward == 0.0
88
+ assert syntax.done is False
89
  assert syntax.execution_status == "syntax_error"
90
 
91
+ runtime = env.step(
92
+ AdaptAction(
93
+ code=(
94
+ "n=int(input())\n"
95
+ "nums=list(map(int,input().split()))\n"
96
+ "print(nums[n])"
97
+ )
98
+ )
99
+ )
100
+ print(runtime)
101
+ assert runtime.execution_status == "runtime_error"
102
+
103
  timeout = env.step(AdaptAction(code="while True:\n pass"))
104
  print(timeout)
105
  assert timeout.timeout_count > 0
106
  assert timeout.execution_status == "timeout"
107
+ assert timeout.done is True
108
 
109
+ observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
110
  unsafe = env.step(AdaptAction(code="import os\nprint(os.listdir('.'))"))
111
  print(unsafe)
112
  assert unsafe.reward == 0.0
113
  assert unsafe.execution_status == "safety_violation"
114
+ assert unsafe.done is False
115
 
116
+ assert env.state.history["attempts"]
 
117
  assert_hidden_tests_are_not_exposed(timeout.model_dump())
118
  print("ADAPT OpenEnv smoke tests passed")
119
 
server/app.py CHANGED
@@ -1,10 +1,12 @@
1
  from __future__ import annotations
2
 
3
  import argparse
 
4
  from typing import Any
 
5
 
6
  import uvicorn
7
- from fastapi import Body, FastAPI, HTTPException, Request
8
  from fastapi.responses import RedirectResponse, Response
9
  from pydantic import BaseModel
10
 
@@ -12,11 +14,15 @@ from env.adapt_env import AdaptEnvironment
12
  from env.test_cases import load_problem_bank
13
  from models import AdaptAction, AdaptObservation, AdaptState
14
 
15
- ENV_NAME = "adapt_dsa_tutor"
16
  ENV_DESCRIPTION = (
17
- "RL environment for DSA code generation with hidden tests, tiered problems, "
18
- "and verifier-aware reward shaping."
19
  )
 
 
 
 
20
  TASKS = [
21
  {
22
  "name": problem["problem_id"],
@@ -26,11 +32,11 @@ TASKS = [
26
  for problem in load_problem_bank()
27
  ]
28
 
29
- app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version="0.2.0")
30
- ENV = AdaptEnvironment()
31
 
32
 
33
  class ResetRequest(BaseModel):
 
34
  seed: int | None = None
35
  episode_id: str | None = None
36
  problem_id: str | None = None
@@ -41,16 +47,47 @@ def _metadata() -> dict[str, Any]:
41
  return {
42
  "name": ENV_NAME,
43
  "description": ENV_DESCRIPTION,
44
- "version": "0.2.0",
45
  "tasks": TASKS,
46
  "mode": "simulation",
47
  }
48
 
49

50
  @app.get("/")
51
  def root() -> dict[str, Any]:
 
52
  payload = _metadata()
53
  payload["status"] = "ok"
 
54
  return payload
55
 
56
 
@@ -70,22 +107,26 @@ def favicon() -> Response:
70
 
71
 
72
  @app.get("/health")
73
- def health() -> dict[str, str]:
74
- return {"status": "healthy"}
 
75
 
76
 
77
  @app.get("/metadata")
78
  def metadata() -> dict[str, Any]:
 
79
  return _metadata()
80
 
81
 
82
  @app.get("/tasks")
83
  def list_tasks() -> dict[str, Any]:
 
84
  return {"tasks": TASKS}
85
 
86
 
87
  @app.get("/schema")
88
  def schema() -> dict[str, Any]:
 
89
  return {
90
  "action": AdaptAction.model_json_schema(),
91
  "observation": AdaptObservation.model_json_schema(),
@@ -95,6 +136,7 @@ def schema() -> dict[str, Any]:
95
 
96
  @app.post("/mcp")
97
  def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
 
98
  return {
99
  "jsonrpc": "2.0",
100
  "id": payload.get("id"),
@@ -107,8 +149,14 @@ def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
107
 
108
  @app.post("/reset")
109
  def reset(request: ResetRequest | None = None) -> dict[str, Any]:
 
110
  effective_request = request or ResetRequest()
111
- observation = ENV.reset(
 
 
 
 
 
112
  seed=effective_request.seed,
113
  episode_id=effective_request.episode_id,
114
  problem_id=effective_request.problem_id,
@@ -119,6 +167,7 @@ def reset(request: ResetRequest | None = None) -> dict[str, Any]:
119
 
120
  @app.post("/step")
121
  async def step(request: Request) -> dict[str, Any]:
 
122
  payload = await request.json()
123
  if not isinstance(payload, dict):
124
  raise HTTPException(status_code=422, detail="Request body must be a JSON object.")
@@ -129,24 +178,31 @@ async def step(request: Request) -> dict[str, Any]:
129
  except Exception as exc:
130
  raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc
131
 
132
- observation = ENV.step(effective_action)
 
 
 
 
133
  return {
134
  "observation": observation.model_dump(),
135
  "reward": float(observation.reward),
136
  "done": bool(observation.done),
137
  "info": {
 
138
  "feedback": observation.feedback,
139
  "pass_rate": observation.pass_rate,
 
140
  "execution_status": observation.execution_status,
141
  },
142
  }
143
 
144
 
145
  @app.get("/state")
146
- def state() -> dict[str, Any]:
147
- if not ENV.problem:
148
- ENV.reset()
149
- return ENV.state.model_dump()
 
150
 
151
 
152
  def main(host: str | None = None, port: int | None = None) -> None:
 
1
  from __future__ import annotations
2
 
3
  import argparse
4
+ from datetime import datetime, timedelta, timezone
5
  from typing import Any
6
+ from uuid import uuid4
7
 
8
  import uvicorn
9
+ from fastapi import Body, FastAPI, HTTPException, Query, Request
10
  from fastapi.responses import RedirectResponse, Response
11
  from pydantic import BaseModel
12
 
 
14
  from env.test_cases import load_problem_bank
15
  from models import AdaptAction, AdaptObservation, AdaptState
16
 
17
+ ENV_NAME = "adapt-dsa-tutor"
18
  ENV_DESCRIPTION = (
19
+ "Adversarial DSA Programming Tutor - RL environment for training LLMs to solve "
20
+ "algorithmic problems through adaptive curriculum and self-repair."
21
  )
22
+ ENV_VERSION = "0.3.0"
23
+ SESSION_TTL = timedelta(minutes=30)
24
+ SESSIONS: dict[str, AdaptEnvironment] = {}
25
+ SESSION_LAST_ACCESSED: dict[str, datetime] = {}
26
  TASKS = [
27
  {
28
  "name": problem["problem_id"],
 
32
  for problem in load_problem_bank()
33
  ]
34
 
35
+ app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version=ENV_VERSION)
 
36
 
37
 
38
  class ResetRequest(BaseModel):
39
+ session_id: str | None = None
40
  seed: int | None = None
41
  episode_id: str | None = None
42
  problem_id: str | None = None
 
47
  return {
48
  "name": ENV_NAME,
49
  "description": ENV_DESCRIPTION,
50
+ "version": ENV_VERSION,
51
  "tasks": TASKS,
52
  "mode": "simulation",
53
  }
54
 
55
 
+ def _utc_now() -> datetime:
+     return datetime.now(timezone.utc)
+
+
+ def _cleanup_sessions() -> None:
+     now = _utc_now()
+     expired = [
+         session_id
+         for session_id, last_seen in SESSION_LAST_ACCESSED.items()
+         if now - last_seen > SESSION_TTL
+     ]
+     for session_id in expired:
+         SESSIONS.pop(session_id, None)
+         SESSION_LAST_ACCESSED.pop(session_id, None)
+
+
+ def _touch_session(session_id: str) -> None:
+     SESSION_LAST_ACCESSED[session_id] = _utc_now()
+
+
+ def _require_session(session_id: str) -> AdaptEnvironment:
+     _cleanup_sessions()
+     env = SESSIONS.get(session_id)
+     if env is None:
+         raise HTTPException(status_code=404, detail=f"Unknown or expired session_id: {session_id}")
+     _touch_session(session_id)
+     return env
+
+
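Reviewer note: the session helpers above implement a sliding time-to-live, so every successful lookup pushes the expiry forward. The sketch below isolates that behaviour with an injectable clock so it can be exercised without FastAPI. `SessionRegistry` and its `put`/`get` methods are hypothetical names for illustration only; the commit itself uses module-level dictionaries.

```python
from datetime import datetime, timedelta, timezone


class SessionRegistry:
    """In-memory registry whose entries expire `ttl` after their last access."""

    def __init__(self, ttl):
        self.ttl = ttl
        self._items = {}
        self._last_seen = {}

    def _evict_expired(self, now):
        # Same strict comparison as the diff: expire only when age exceeds ttl.
        expired = [key for key, seen in self._last_seen.items() if now - seen > self.ttl]
        for key in expired:
            self._items.pop(key, None)
            self._last_seen.pop(key, None)

    def put(self, session_id, item, now):
        self._evict_expired(now)
        self._items[session_id] = item
        self._last_seen[session_id] = now

    def get(self, session_id, now):
        """Return the stored item and refresh its timestamp, or None if missing."""
        self._evict_expired(now)
        item = self._items.get(session_id)
        if item is not None:
            self._last_seen[session_id] = now
        return item
```

Passing `now` explicitly is a testing convenience here; production code would more likely read the clock internally, as `_utc_now()` does.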
85
  @app.get("/")
86
  def root() -> dict[str, Any]:
87
+ _cleanup_sessions()
88
  payload = _metadata()
89
  payload["status"] = "ok"
90
+ payload["active_sessions"] = len(SESSIONS)
91
  return payload
92
 
93
 
 
107
 
108
 
109
  @app.get("/health")
110
+ def health() -> dict[str, Any]:
111
+ _cleanup_sessions()
112
+ return {"status": "healthy", "active_sessions": len(SESSIONS)}
113
 
114
 
115
  @app.get("/metadata")
116
  def metadata() -> dict[str, Any]:
117
+ _cleanup_sessions()
118
  return _metadata()
119
 
120
 
121
  @app.get("/tasks")
122
  def list_tasks() -> dict[str, Any]:
123
+ _cleanup_sessions()
124
  return {"tasks": TASKS}
125
 
126
 
127
  @app.get("/schema")
128
  def schema() -> dict[str, Any]:
129
+ _cleanup_sessions()
130
  return {
131
  "action": AdaptAction.model_json_schema(),
132
  "observation": AdaptObservation.model_json_schema(),
 
136
 
137
  @app.post("/mcp")
138
  def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
139
+ _cleanup_sessions()
140
  return {
141
  "jsonrpc": "2.0",
142
  "id": payload.get("id"),
 
149
 
150
  @app.post("/reset")
151
  def reset(request: ResetRequest | None = None) -> dict[str, Any]:
152
+ _cleanup_sessions()
153
  effective_request = request or ResetRequest()
154
+ session_id = effective_request.session_id or str(uuid4())
155
+ env = AdaptEnvironment(session_id=session_id)
156
+ SESSIONS[session_id] = env
157
+ _touch_session(session_id)
158
+ observation = env.reset(
159
+ session_id=session_id,
160
  seed=effective_request.seed,
161
  episode_id=effective_request.episode_id,
162
  problem_id=effective_request.problem_id,
 
167
 
168
  @app.post("/step")
169
  async def step(request: Request) -> dict[str, Any]:
170
+ _cleanup_sessions()
171
  payload = await request.json()
172
  if not isinstance(payload, dict):
173
  raise HTTPException(status_code=422, detail="Request body must be a JSON object.")
 
178
  except Exception as exc:
179
  raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc
180
 
181
+ if not effective_action.session_id:
182
+ raise HTTPException(status_code=422, detail="`session_id` is required in the /step request body.")
183
+
184
+ env = _require_session(effective_action.session_id)
185
+ observation = env.step(effective_action)
186
  return {
187
  "observation": observation.model_dump(),
188
  "reward": float(observation.reward),
189
  "done": bool(observation.done),
190
  "info": {
191
+ "session_id": observation.session_id,
192
  "feedback": observation.feedback,
193
  "pass_rate": observation.pass_rate,
194
+ "visible_pass_rate": observation.visible_pass_rate,
195
  "execution_status": observation.execution_status,
196
  },
197
  }
198
 
199
 
200
  @app.get("/state")
201
+ def state(session_id: str = Query(..., description="Session id returned from /reset.")) -> dict[str, Any]:
202
+ env = _require_session(session_id)
203
+ if not env.problem:
204
+ env.reset(session_id=session_id)
205
+ return env.state.model_dump()
206
 
207
 
208
  def main(host: str | None = None, port: int | None = None) -> None:
training/plot_results.py ADDED
@@ -0,0 +1,139 @@
+ from __future__ import annotations
+
+ import argparse
+ import csv
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any
+
+
+ def read_rows(csv_path: Path) -> list[dict[str, Any]]:
+     with csv_path.open("r", encoding="utf-8", newline="") as handle:
+         reader = csv.DictReader(handle)
+         rows: list[dict[str, Any]] = []
+         for row in reader:
+             parsed: dict[str, Any] = {}
+             for key, value in row.items():
+                 if value is None:
+                     parsed[key] = value
+                     continue
+                 value = value.strip()
+                 if value == "":
+                     parsed[key] = value
+                     continue
+                 try:
+                     parsed[key] = float(value) if "." in value else int(value)
+                 except ValueError:
+                     parsed[key] = value
+             rows.append(parsed)
+     return rows
+
+
+ def rolling_mean(values: list[float], window: int) -> list[float]:
+     output: list[float] = []
+     for index in range(len(values)):
+         start = max(0, index - window + 1)
+         chunk = values[start : index + 1]
+         output.append(sum(chunk) / len(chunk))
+     return output
+
+
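Reviewer note: `rolling_mean` is a trailing window that shrinks at the start of the series, so the first `window - 1` outputs average fewer points. The sketch below restates it on plain lists and adds a running-sum variant for comparison. `rolling_mean_linear` is a hypothetical alternative, not part of the commit, and it assumes `window >= 1`.

```python
def rolling_mean(values, window):
    """Trailing mean as in the diff: each output averages up to `window` points."""
    output = []
    for index in range(len(values)):
        start = max(0, index - window + 1)
        chunk = values[start : index + 1]
        output.append(sum(chunk) / len(chunk))
    return output


def rolling_mean_linear(values, window):
    """Same result in O(n) by keeping a running sum of the current window."""
    output = []
    running = 0.0
    for index, value in enumerate(values):
        running += value
        if index >= window:
            # Drop the element that just slid out of the window.
            running -= values[index - window]
        output.append(running / min(index + 1, window))
    return output
```

For curves of a few thousand steps the quadratic version is fine. The running-sum form mainly matters for long logs, and it can accumulate small floating-point drift.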
41
+ def plot_reward_curve(rows: list[dict[str, Any]], output_dir: Path) -> None:
42
+ import matplotlib.pyplot as plt
43
+
44
+ train_rows = [row for row in rows if row.get("phase") == "train"]
45
+ steps = [int(row["step"]) for row in train_rows]
46
+ rewards = [float(row["episode_reward"]) for row in train_rows]
47
+ reward_smooth = rolling_mean(rewards, window=20)
48
+
49
+ plt.figure(figsize=(10, 5))
50
+ plt.plot(steps, rewards, alpha=0.25, label="Episode reward")
51
+ plt.plot(steps, reward_smooth, linewidth=2, label="20-step moving average")
52
+ plt.xlabel("Training step")
53
+ plt.ylabel("Reward")
54
+ plt.title("ADAPT Training Reward Curve")
55
+ plt.legend()
56
+ plt.tight_layout()
57
+ plt.savefig(output_dir / "reward_curve.png", dpi=200)
58
+ plt.close()
59
+
60
+
+ def plot_pass_rate_by_difficulty(rows: list[dict[str, Any]], output_dir: Path) -> None:
+     import matplotlib.pyplot as plt
+
+     train_rows = [row for row in rows if row.get("phase") == "train"]
+     grouped: dict[str, list[tuple[int, float]]] = defaultdict(list)
+     for row in train_rows:
+         grouped[str(row["difficulty_tier"])].append((int(row["step"]), float(row["pass_rate"])))
+
+     plt.figure(figsize=(10, 5))
+     for difficulty in ("easy", "medium", "hard"):
+         points = grouped.get(difficulty, [])
+         if not points:
+             continue
+         steps = [step for step, _ in points]
+         values = [value for _, value in points]
+         smooth = rolling_mean(values, window=10)
+         plt.plot(steps, smooth, linewidth=2, label=difficulty.title())
+
+     plt.xlabel("Training step")
+     plt.ylabel("Pass rate")
+     plt.title("Pass Rate by Difficulty Tier")
+     plt.legend()
+     plt.tight_layout()
+     plt.savefig(output_dir / "pass_rate_by_difficulty.png", dpi=200)
+     plt.close()
+
+
+ def plot_family_productivity(rows: list[dict[str, Any]], output_dir: Path) -> None:
+     import matplotlib.pyplot as plt
+
+     train_rows = [row for row in rows if row.get("phase") == "train"]
+     if not train_rows:
+         # No training rows (for example, an eval-only CSV): nothing to rank or plot.
+         return
+     productivity_columns = [key for key in train_rows[0].keys() if str(key).startswith("family_productivity__")]
+     if not productivity_columns:
+         return
+
+     ranked_columns = sorted(
+         productivity_columns,
+         key=lambda column: float(train_rows[-1].get(column, 0.0)),
+         reverse=True,
+     )[:8]
+
+     plt.figure(figsize=(11, 6))
+     steps = [int(row["step"]) for row in train_rows]
+     for column in ranked_columns:
+         family = column.split("__", 1)[1]
+         values = [float(row.get(column, 0.0)) for row in train_rows]
+         plt.plot(steps, values, linewidth=2, label=family)
+
+     plt.xlabel("Training step")
+     plt.ylabel("Family productivity EMA")
+     plt.title("Reward-Aware Family Productivity Over Training")
+     plt.legend(loc="upper left", fontsize=8)
+     plt.tight_layout()
+     plt.savefig(output_dir / "family_productivity.png", dpi=200)
+     plt.close()
+
+
+ def main(argv: list[str] | None = None) -> None:
+     parser = argparse.ArgumentParser(description="Plot ADAPT reward and curriculum artifacts from reward_curve.csv.")
+     parser.add_argument("csv_path", help="Path to reward_curve.csv")
+     parser.add_argument("--output-dir", default=None, help="Directory for PNG outputs. Defaults to the CSV directory.")
+     args = parser.parse_args(argv)
+
+     csv_path = Path(args.csv_path)
+     output_dir = Path(args.output_dir) if args.output_dir else csv_path.parent
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     rows = read_rows(csv_path)
+     if not rows:
+         raise RuntimeError(f"No rows found in {csv_path}")
+
+     plot_reward_curve(rows, output_dir)
+     plot_pass_rate_by_difficulty(rows, output_dir)
+     plot_family_productivity(rows, output_dir)
+     print(f"Saved plots to {output_dir}")
+
+
+ if __name__ == "__main__":
+     main()
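The reward and pass-rate plots in `training/plot_results.py` smooth noisy per-episode values with `rolling_mean`, a trailing-window average that shortens the window at the start of the series instead of padding. A minimal standalone sketch of that behavior follows; the function body mirrors the diff, while the sample rewards are made up for illustration.

```python
def rolling_mean(values: list[float], window: int) -> list[float]:
    """Trailing mean: each output averages at most `window` of the latest values."""
    output: list[float] = []
    for index in range(len(values)):
        start = max(0, index - window + 1)  # clamp so early points use a shorter window
        chunk = values[start : index + 1]
        output.append(sum(chunk) / len(chunk))
    return output


rewards = [0.0, 1.0, 1.0, 0.0, 1.0]  # hypothetical episode rewards
smoothed = rolling_mean(rewards, window=2)
print(smoothed)
```

Because the window only looks backward, the smoothed curve has the same length as the input and can be plotted against the same step axis.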
training/train_grpo.py CHANGED
@@ -1,14 +1,25 @@
1
  from __future__ import annotations
2
 
3
  import argparse
 
4
  import json
 
5
  from dataclasses import dataclass, field
 
6
  from typing import Any
7
 
8
- from env.adapt_env import AdaptEnvironment
9
- from env.generator import GeneratorAgent
 
 
10
  from models import AdaptAction
11
 
 
 
 
 
 
 
12
 
13
  def extract_code(completion: str) -> str:
14
  text = completion.strip()
@@ -19,21 +30,45 @@ def extract_code(completion: str) -> str:
19
  return text
20
 
21
 
22
- def build_solver_prompt(problem: dict[str, Any]) -> str:
23
- public_problem = {
24
  "problem_id": problem["problem_id"],
 
25
  "difficulty": problem["difficulty_label"],
26
- "problem": problem["problem"],
 
 
27
  "input_format": problem["input_format"],
28
  "constraints": problem["constraints"],
 
29
  }
30
- return (
31
- "You are the Solver Agent for ADAPT.\n"
32
- "Read the generated DSA task and reply with only runnable Python code.\n"
33
- "The program must read from stdin and print to stdout.\n"
34
- "No markdown, no explanation.\n\n"
35
- f"{json.dumps(public_problem, indent=2)}"
36
- )
37
 
38
 
39
  @dataclass
@@ -42,29 +77,37 @@ class CurriculumManager:
42
  current_idx: int = 0
43
  success_history: list[float] = field(default_factory=list)
44
  window_size: int = 10
 
 
45
 
46
  def current_difficulty(self) -> str:
47
  return self.difficulties[self.current_idx]
48
 
49
- def update(self, batch_success_rate: float) -> None:
50
- self.success_history.append(float(batch_success_rate))
 
 
 
51
  if len(self.success_history) > self.window_size:
52
  self.success_history.pop(0)
53
 
 
 
 
54
  moving_average = sum(self.success_history) / len(self.success_history)
55
- if moving_average > 0.70 and self.current_idx < len(self.difficulties) - 1:
56
  self.current_idx += 1
57
  self.success_history.clear()
58
  print(
59
  f"[curriculum] promoted to {self.current_difficulty()} "
60
- f"(moving_success={moving_average:.2f})"
61
  )
62
- elif moving_average < 0.25 and self.current_idx > 0:
63
  self.current_idx -= 1
64
  self.success_history.clear()
65
  print(
66
- f"[curriculum] reduced to {self.current_difficulty()} "
67
- f"(moving_success={moving_average:.2f})"
68
  )
69
 
70
 
@@ -72,6 +115,7 @@ class CurriculumManager:
72
  class GeneratorController:
73
  mode: str = "heuristic"
74
  deterministic: bool = True
 
75
  generator: GeneratorAgent = field(init=False)
76
  history: dict[str, Any] = field(
77
  default_factory=lambda: {
@@ -83,35 +127,80 @@ class GeneratorController:
83
  }
84
  )
85
  prompt_registry: dict[str, dict[str, Any]] = field(default_factory=dict)
 
86
 
87
  def __post_init__(self) -> None:
88
  self.generator = GeneratorAgent(deterministic=self.deterministic)
89
 
90
  def create_rollout_problem(self, difficulty: str) -> tuple[str, dict[str, Any]]:
91
- problem = self.generator.generate(difficulty, self.history)
92
- prompt = build_solver_prompt(problem)
93
  self.prompt_registry[prompt] = problem
94
  return prompt, problem
95
 
96
  def resolve_prompt(self, prompt: str) -> dict[str, Any]:
97
  if prompt not in self.prompt_registry:
98
  raise KeyError("Prompt was not registered with the generator controller.")
99
- return self.prompt_registry[prompt]
100
-
101
- def update(self, problem: dict[str, Any], pass_rate: float, generator_reward_signal: float) -> None:
102
  self.history["recent_pass_rates"].append(round(float(pass_rate), 4))
103
  self.history["problem_types"].append(problem.get("problem_type", ""))
104
  self.history["problem_signatures"].append(problem.get("problem_id", ""))
105
- if self.mode == "reward_aware":
106
- self.history["generator_rewards"].append(round(float(generator_reward_signal), 4))
107
- else:
108
- self.history["generator_rewards"].append(0.0)
109
  self.history["episode_index"] = int(self.history.get("episode_index", 0)) + 1
110
 
 
 
 
 
 
 
111
  for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
112
  values = self.history[key]
113
- if len(values) > 50:
114
- del values[:-50]
115
 
116
  def stats_snapshot(self) -> dict[str, Any]:
117
  return {
@@ -120,6 +209,13 @@ class GeneratorController:
120
  "recent_pass_rates": list(self.history["recent_pass_rates"][-5:]),
121
  "recent_problem_types": list(self.history["problem_types"][-5:]),
122
  "recent_generator_rewards": list(self.history["generator_rewards"][-5:]),
 
 
 
 
 
 
 
123
  }
124
 
125
 
@@ -138,50 +234,273 @@ class GeneratorRolloutDataset:
138
  return {"prompt": prompt}
139
 
140
 
141
- def build_reward_func(curriculum: CurriculumManager, controller: GeneratorController):
142
  def reward_func(prompts, completions, **kwargs) -> list[float]:
143
  del kwargs
144
- env = AdaptEnvironment(generator=controller.generator, generator_mode=controller.mode)
145
  rewards: list[float] = []
146
- pass_rates: list[float] = []
147
 
148
  for prompt, completion in zip(prompts, completions):
149
  problem = controller.resolve_prompt(prompt)
 
150
  env.reset(
151
  difficulty=problem["difficulty_label"],
152
  generated_problem=problem,
153
  generator_mode=controller.mode,
 
 
 
 
 
 
 
154
  )
155
- observation = env.step(AdaptAction(code=extract_code(completion)))
156
  rewards.append(float(observation.reward))
157
- pass_rates.append(float(observation.pass_rate))
158
- controller.update(problem, observation.pass_rate, observation.generator_reward_signal)
159
- print(
160
- "[rollout]",
161
- json.dumps(
162
- {
163
- "problem_id": problem["problem_id"],
164
- "problem_type": problem["problem_type"],
165
- "difficulty": problem["difficulty_label"],
166
- "solver_reward": observation.reward,
167
- "pass_rate": observation.pass_rate,
168
- "generator_reward": observation.generator_reward_signal,
169
- "status": observation.execution_status,
170
- }
171
- ),
172
  )
173
-
174
- if pass_rates:
175
- curriculum.update(sum(pass_rates) / len(pass_rates))
176
- print("[generator]", json.dumps(controller.stats_snapshot()))
177
 
178
  return rewards
179
 
180
  return reward_func
181
 
182
 
183
- def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
184
- return GeneratorRolloutDataset(size=size, controller=controller, curriculum=curriculum)
185
 
186
 
187
  def run_training(args: argparse.Namespace) -> None:
@@ -193,6 +512,9 @@ def run_training(args: argparse.Namespace) -> None:
193
  "Training dependencies are missing. Install `trl` and `unsloth` before running GRPO training."
194
  ) from exc
195
 
 
 
 
196
  PatchFastRL("GRPO", FastLanguageModel)
197
 
198
  model, tokenizer = FastLanguageModel.from_pretrained(
@@ -200,6 +522,8 @@ def run_training(args: argparse.Namespace) -> None:
200
  max_seq_length=args.max_seq_length,
201
  load_in_4bit=not args.disable_4bit,
202
  )
 
 
203
 
204
  model = FastLanguageModel.get_peft_model(
205
  model,
@@ -214,6 +538,31 @@ def run_training(args: argparse.Namespace) -> None:
214
  mode="reward_aware" if args.generator_mode == "reward_aware" else "heuristic",
215
  deterministic=not args.non_deterministic_generator,
216
  )
217
  training_args = GRPOConfig(
218
  output_dir=args.output_dir,
219
  learning_rate=args.learning_rate,
@@ -225,41 +574,68 @@ def run_training(args: argparse.Namespace) -> None:
225
  max_steps=args.max_steps,
226
  logging_steps=1,
227
  bf16=args.bf16,
 
228
  )
229
 
230
  trainer = GRPOTrainer(
231
  model=model,
232
- reward_funcs=[build_reward_func(curriculum, controller)],
233
  args=training_args,
234
  train_dataset=build_dataset(args.dataset_size, controller, curriculum),
235
  )
236
  trainer.train()
 
237
  model.save_pretrained(args.output_dir)
238
  tokenizer.save_pretrained(args.output_dir)
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  def build_parser() -> argparse.ArgumentParser:
242
  parser = argparse.ArgumentParser(description="GRPO training entrypoint for the ADAPT DSA environment.")
243
  parser.add_argument("--model-name", default="unsloth/Llama-3.2-3B-Instruct")
244
- parser.add_argument("--output-dir", default="outputs_v2")
245
  parser.add_argument("--dataset-size", type=int, default=200)
246
  parser.add_argument("--max-steps", type=int, default=250)
247
  parser.add_argument("--batch-size", type=int, default=1)
248
  parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
249
  parser.add_argument("--num-generations", type=int, default=8)
250
  parser.add_argument("--max-seq-length", type=int, default=2048)
251
- parser.add_argument("--max-prompt-length", type=int, default=768)
252
  parser.add_argument("--max-completion-length", type=int, default=512)
253
  parser.add_argument("--learning-rate", type=float, default=5e-6)
254
  parser.add_argument("--lora-rank", type=int, default=16)
255
  parser.add_argument("--lora-alpha", type=int, default=16)
256
  parser.add_argument("--disable-4bit", action="store_true")
257
  parser.add_argument("--bf16", action="store_true")
 
 
 
 
 
 
258
  parser.add_argument(
259
  "--generator-mode",
260
  choices=["heuristic", "reward_aware"],
261
- default="heuristic",
262
- help="Use heuristic generation (V1/V2) or reward-aware bookkeeping for V3-ready training.",
263
  )
264
  parser.add_argument(
265
  "--non-deterministic-generator",
 
1
  from __future__ import annotations
2
 
3
  import argparse
4
+ import csv
5
  import json
6
+ import math
7
  from dataclasses import dataclass, field
8
+ from pathlib import Path
9
  from typing import Any
10
 
11
+ import torch
12
+
13
+ from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
14
+ from env.generator import DIFFICULTY_LABELS, GeneratorAgent
15
  from models import AdaptAction
16
 
17
+ SYSTEM_PROMPT = """You are the Solver Agent for ADAPT.
18
+ Write only runnable Python code.
19
+ The program must read from stdin and print to stdout.
20
+ If feedback is present, repair your previous solution instead of starting from scratch.
21
+ Do not include markdown fences or explanations."""
22
+
23
 
24
  def extract_code(completion: str) -> str:
25
  text = completion.strip()
 
30
  return text
31
 
32
 
+ def format_examples(problem: dict[str, Any]) -> str:
+     visible_cases = [test_case for test_case in problem.get("test_cases", []) if test_case.get("is_visible", False)]
+     if not visible_cases:
+         return problem["problem"]
+
+     chunks = []
+     for test_case in visible_cases:
+         chunks.append(f"Input:\n{test_case['input']}Expected Output:\n{test_case['output']}\n")
+     return f"{problem['problem']}\n\nExamples:\n" + "\n".join(chunks).rstrip()
+
+
+ def build_solver_prompt(payload: dict[str, Any]) -> str:
+     feedback = payload.get("feedback") or "No previous attempt yet."
+     return (
+         f"{SYSTEM_PROMPT}\n\n"
+         f"Problem ID: {payload['problem_id']}\n"
+         f"Problem Family: {payload['problem_type']}\n"
+         f"Difficulty: {payload['difficulty']}\n"
+         f"Attempt: {payload.get('attempt_number', 0)}/{payload.get('max_steps', MAX_STEPS_PER_EPISODE)}\n\n"
+         f"Problem:\n{payload['problem']}\n\n"
+         f"Input Format:\n{payload['input_format']}\n\n"
+         f"Constraints:\n{payload['constraints']}\n\n"
+         f"Feedback:\n{feedback}\n"
+     )
+
+
+ def build_prompt_from_problem(problem: dict[str, Any]) -> str:
+     payload = {
          "problem_id": problem["problem_id"],
+         "problem_type": problem["problem_type"],
          "difficulty": problem["difficulty_label"],
+         "attempt_number": 0,
+         "max_steps": MAX_STEPS_PER_EPISODE,
+         "problem": format_examples(problem),
          "input_format": problem["input_format"],
          "constraints": problem["constraints"],
+         "feedback": "No previous attempt yet. Solve the problem directly from the examples and constraints.",
      }
+     return build_solver_prompt(payload)
72
 
73
 
74
  @dataclass
 
      current_idx: int = 0
      success_history: list[float] = field(default_factory=list)
      window_size: int = 10
+     promote_threshold: float = 0.70
+     demote_threshold: float = 0.30

      def current_difficulty(self) -> str:
          return self.difficulties[self.current_idx]

+     def current_level(self) -> int:
+         return self.current_idx + 1
+
+     def update(self, episode_pass_rate: float) -> None:
+         self.success_history.append(float(episode_pass_rate))
          if len(self.success_history) > self.window_size:
              self.success_history.pop(0)

+         if len(self.success_history) < self.window_size:
+             return
+
          moving_average = sum(self.success_history) / len(self.success_history)
+         if moving_average >= self.promote_threshold and self.current_idx < len(self.difficulties) - 1:
              self.current_idx += 1
              self.success_history.clear()
              print(
                  f"[curriculum] promoted to {self.current_difficulty()} "
+                 f"(moving_pass_rate={moving_average:.2f})"
              )
+         elif moving_average <= self.demote_threshold and self.current_idx > 0:
              self.current_idx -= 1
              self.success_history.clear()
              print(
+                 f"[curriculum] demoted to {self.current_difficulty()} "
+                 f"(moving_pass_rate={moving_average:.2f})"
              )

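The new `CurriculumManager.update` only acts once a full window of episode pass rates has accumulated, then promotes at or above `promote_threshold` and demotes at or below `demote_threshold`, clearing the window after each move. The sketch below is a simplified stand-in, not the class from the diff: `MiniCurriculum` is a hypothetical name, and the tier tuple is an assumption because the `difficulties` field sits outside the visible hunk.

```python
from dataclasses import dataclass, field


@dataclass
class MiniCurriculum:
    # Assumed tier names; the real default for `difficulties` is not shown in the diff.
    difficulties: tuple[str, ...] = ("easy", "medium", "hard")
    current_idx: int = 0
    history: list[float] = field(default_factory=list)
    window_size: int = 10
    promote_threshold: float = 0.70
    demote_threshold: float = 0.30

    def update(self, pass_rate: float) -> str:
        self.history.append(float(pass_rate))
        if len(self.history) > self.window_size:
            self.history.pop(0)
        # Wait for a full window so a few lucky episodes cannot trigger a move.
        if len(self.history) < self.window_size:
            return self.difficulties[self.current_idx]
        average = sum(self.history) / len(self.history)
        if average >= self.promote_threshold and self.current_idx < len(self.difficulties) - 1:
            self.current_idx += 1
            self.history.clear()
        elif average <= self.demote_threshold and self.current_idx > 0:
            self.current_idx -= 1
            self.history.clear()
        return self.difficulties[self.current_idx]


curriculum = MiniCurriculum()
tiers = [curriculum.update(1.0) for _ in range(10)]
print(tiers[8], tiers[9])
```

Clearing the history after a tier change means the next decision needs another full window at the new tier, which keeps the curriculum from oscillating on stale results.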
 
 
115
  class GeneratorController:
116
  mode: str = "heuristic"
117
  deterministic: bool = True
118
+ temperature: float = 0.5
119
  generator: GeneratorAgent = field(init=False)
120
  history: dict[str, Any] = field(
121
  default_factory=lambda: {
 
127
  }
128
  )
129
  prompt_registry: dict[str, dict[str, Any]] = field(default_factory=dict)
+     family_productivity: dict[str, float] = field(default_factory=dict)

      def __post_init__(self) -> None:
          self.generator = GeneratorAgent(deterministic=self.deterministic)
+         if not self.family_productivity:
+             self.family_productivity = {
+                 template.problem_type: 0.0 for template in self.generator.templates
+             }
+
+     @property
+     def family_names(self) -> list[str]:
+         return sorted(self.family_productivity)
+
+     def sample_problem(self, difficulty: str) -> dict[str, Any]:
+         family_weights = self.family_weights_for_difficulty(difficulty)
+         problem = self.generator.generate_problem(
+             difficulty_level=difficulty,
+             history=self.history,
+             family_weights=family_weights,
+         )
+         return problem

      def create_rollout_problem(self, difficulty: str) -> tuple[str, dict[str, Any]]:
+         problem = self.sample_problem(difficulty)
+         prompt = build_prompt_from_problem(problem)
          self.prompt_registry[prompt] = problem
          return prompt, problem

      def resolve_prompt(self, prompt: str) -> dict[str, Any]:
          if prompt not in self.prompt_registry:
              raise KeyError("Prompt was not registered with the generator controller.")
+         # Plain lookup rather than pop(): GRPO scores several completions per prompt,
+         # so the same prompt can be resolved more than once in a reward batch.
+         return self.prompt_registry[prompt]
+
+     def family_weights_for_difficulty(self, difficulty: str) -> dict[str, float] | None:
+         if self.mode != "reward_aware":
+             return None
+
+         eligible = [
+             template.problem_type
+             for template in self.generator.templates
+             if DIFFICULTY_LABELS[template.difficulty_tier] == difficulty
+         ]
+         if not eligible:
+             return None
+
+         logits = [self.family_productivity.get(family, 0.0) / self.temperature for family in eligible]
+         max_logit = max(logits)
+         exp_values = [math.exp(logit - max_logit) for logit in logits]
+         return {family: value for family, value in zip(eligible, exp_values)}
+
+     def update(
+         self,
+         problem: dict[str, Any],
+         pass_rate: float,
+         generator_reward_signal: float,
+         *,
+         update_productivity: bool = True,
+     ) -> None:
          self.history["recent_pass_rates"].append(round(float(pass_rate), 4))
          self.history["problem_types"].append(problem.get("problem_type", ""))
          self.history["problem_signatures"].append(problem.get("problem_id", ""))
+         self.history["generator_rewards"].append(round(float(generator_reward_signal), 4))
          self.history["episode_index"] = int(self.history.get("episode_index", 0)) + 1

+         if self.mode == "reward_aware" and update_productivity:
+             family = problem.get("problem_type", "")
+             current = float(self.family_productivity.get(family, 0.0))
+             updated = 0.9 * current + 0.1 * float(generator_reward_signal)
+             self.family_productivity[family] = round(updated, 6)
+
          for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
              values = self.history[key]
+             if len(values) > 100:
+                 del values[:-100]

205
  def stats_snapshot(self) -> dict[str, Any]:
206
  return {
 
209
  "recent_pass_rates": list(self.history["recent_pass_rates"][-5:]),
210
  "recent_problem_types": list(self.history["problem_types"][-5:]),
211
  "recent_generator_rewards": list(self.history["generator_rewards"][-5:]),
212
+ "family_productivity": self.productivity_snapshot(),
213
+ }
214
+
215
+ def productivity_snapshot(self) -> dict[str, float]:
216
+ return {
217
+ family: round(float(value), 6)
218
+ for family, value in sorted(self.family_productivity.items())
219
  }
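In `reward_aware` mode the controller keeps an exponential moving average of the generator reward per problem family and turns those scores into sampling weights with a temperature-scaled, max-shifted exponential. The sketch below isolates those two steps under stated assumptions: `ema_update` and `softmax_weights` are hypothetical helper names, the family names are invented, and the final normalization is added only for illustration, since the diff returns unnormalized weights and does not show how `generate_problem` consumes them.

```python
import math


def ema_update(current: float, reward: float) -> float:
    # Same 0.9 / 0.1 blend as GeneratorController.update in the diff.
    return 0.9 * current + 0.1 * reward


def softmax_weights(productivity: dict[str, float], temperature: float = 0.5) -> dict[str, float]:
    """Unnormalized softmax weights, shifted by the max logit for numerical stability."""
    logits = {family: score / temperature for family, score in productivity.items()}
    max_logit = max(logits.values())
    return {family: math.exp(logit - max_logit) for family, logit in logits.items()}


productivity = {"two_pointers": 0.0, "prefix_sum": 0.0}  # hypothetical family names
productivity["prefix_sum"] = ema_update(productivity["prefix_sum"], reward=1.0)

weights = softmax_weights(productivity)
total = sum(weights.values())
probabilities = {family: weight / total for family, weight in weights.items()}
print(probabilities)
```

A lower temperature sharpens the preference for productive families, while a higher one keeps sampling closer to uniform.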
220
 
221
 
 
234
  return {"prompt": prompt}
235
 
236
 
+ @dataclass
+ class TrainingLogger:
+     output_dir: Path
+     family_names: list[str]
+     use_wandb: bool = True
+     wandb_project: str = "adapt-dsa-tutor"
+     wandb_run_name: str | None = None
+     rows: list[dict[str, Any]] = field(default_factory=list)
+     global_step: int = 0
+     _wandb_run: Any = field(default=None, init=False, repr=False)
+
+     def __post_init__(self) -> None:
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         if not self.use_wandb:
+             return
+         try:
+             import wandb
+
+             self._wandb_run = wandb.init(
+                 project=self.wandb_project,
+                 name=self.wandb_run_name,
+                 config={"family_names": self.family_names},
+                 reinit=True,
+             )
+         except Exception:
+             self._wandb_run = None
+
+     def log_event(
+         self,
+         *,
+         phase: str,
+         episode_reward: float,
+         pass_rate: float,
+         visible_pass_rate: float,
+         difficulty_tier: str,
+         problem_family: str,
+         curriculum_level: int,
+         execution_status: str,
+         attempt_number: int,
+         family_productivity: dict[str, float],
+         extra: dict[str, Any] | None = None,
+     ) -> None:
+         row: dict[str, Any] = {
+             "step": self.global_step,
+             "phase": phase,
+             "episode_reward": round(float(episode_reward), 4),
+             "pass_rate": round(float(pass_rate), 4),
+             "visible_pass_rate": round(float(visible_pass_rate), 4),
+             "difficulty_tier": difficulty_tier,
+             "problem_family": problem_family,
+             "curriculum_level": curriculum_level,
+             "execution_status": execution_status,
+             "attempt_number": int(attempt_number),
+         }
+         for family in self.family_names:
+             row[f"family_productivity__{family}"] = round(float(family_productivity.get(family, 0.0)), 6)
+         if extra:
+             row.update(extra)
+         self.rows.append(row)
+         if self._wandb_run is not None:
+             self._wandb_run.log(row, step=self.global_step)
+         self.global_step += 1
+
+     def write_csv(self) -> Path:
+         output_path = self.output_dir / "reward_curve.csv"
+         fieldnames: list[str] = []
+         for row in self.rows:
+             for key in row:
+                 if key not in fieldnames:
+                     fieldnames.append(key)
+         with output_path.open("w", newline="", encoding="utf-8") as handle:
+             writer = csv.DictWriter(handle, fieldnames=fieldnames)
+             writer.writeheader()
+             writer.writerows(self.rows)
+         return output_path
+
+     def close(self) -> None:
+         if self._wandb_run is not None:
+             self._wandb_run.finish()
+
+
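`TrainingLogger.write_csv` builds the CSV header as the ordered union of keys across all logged rows, so rows that carry extra fields do not break the writer and rows that lack a field get an empty cell. A minimal sketch of that pattern, using an in-memory buffer instead of a file; the row contents are invented for illustration.

```python
import csv
import io

rows = [
    {"step": 0, "phase": "train", "pass_rate": 0.5},
    {"step": 1, "phase": "train", "pass_rate": 1.0, "problem_id": "demo-1"},  # extra key
]

# Ordered union of keys: first-seen order is preserved across rows.
fieldnames: list[str] = []
for row in rows:
    for key in row:
        if key not in fieldnames:
            fieldnames.append(key)

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=fieldnames)  # missing keys default to ""
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)
```

The membership check on a list is quadratic in the number of columns, which is fine for a few dozen fields but would be worth replacing with a dict for very wide logs.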
+ def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
+     return GeneratorRolloutDataset(size=size, controller=controller, curriculum=curriculum)
+
+
+ def build_reward_func(
+     curriculum: CurriculumManager,
+     controller: GeneratorController,
+     logger: TrainingLogger,
+ ):
      def reward_func(prompts, completions, **kwargs) -> list[float]:
          del kwargs
          rewards: list[float] = []

          for prompt, completion in zip(prompts, completions):
              problem = controller.resolve_prompt(prompt)
+             env = AdaptEnvironment(generator=controller.generator, generator_mode=controller.mode)
              env.reset(
                  difficulty=problem["difficulty_label"],
                  generated_problem=problem,
                  generator_mode=controller.mode,
+                 session_id=env.session_id,
+             )
+             observation = env.step(
+                 AdaptAction(
+                     session_id=env.session_id,
+                     code=extract_code(completion),
+                 )
              )
              rewards.append(float(observation.reward))
+             controller.update(
+                 problem=problem,
+                 pass_rate=observation.pass_rate,
+                 generator_reward_signal=observation.generator_reward_signal,
              )
+             curriculum.update(observation.pass_rate)
+             logger.log_event(
+                 phase="train",
+                 episode_reward=float(observation.reward),
+                 pass_rate=float(observation.pass_rate),
+                 visible_pass_rate=float(observation.visible_pass_rate),
+                 difficulty_tier=problem["difficulty_label"],
+                 problem_family=problem["problem_type"],
+                 curriculum_level=curriculum.current_level(),
+                 execution_status=observation.execution_status,
+                 attempt_number=int(observation.attempt_number),
+                 family_productivity=controller.productivity_snapshot(),
+                 extra={
+                     "generator_reward": round(float(observation.generator_reward_signal), 4),
+                     "problem_id": problem["problem_id"],
+                 },
+             )
+             if controller.mode == "reward_aware" and controller.history["episode_index"] % 50 == 0:
+                 print("[family_productivity]", json.dumps(controller.productivity_snapshot()))

          return rewards

      return reward_func


+ def render_prompt(tokenizer: Any, prompt: str) -> str:
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {"role": "user", "content": prompt},
+     ]
+     if hasattr(tokenizer, "apply_chat_template"):
+         return tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True,
+         )
+     return f"{SYSTEM_PROMPT}\n\n{prompt}"
+
+
+ def generate_completion(
+     model: Any,
+     tokenizer: Any,
+     prompt: str,
+     *,
+     max_new_tokens: int,
+ ) -> str:
+     rendered = render_prompt(tokenizer, prompt)
+     inputs = tokenizer(rendered, return_tensors="pt")
+     device = getattr(model, "device", None)
+     if device is None:
+         device = next(model.parameters()).device
+     inputs = {key: value.to(device) for key, value in inputs.items()}
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+     generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
+     return tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+
+ def run_policy_evaluation(
+     *,
+     model: Any,
+     tokenizer: Any,
+     generator_mode: str,
+     deterministic_generator: bool,
+     episodes: int,
+     logger: TrainingLogger,
+     phase: str,
+     max_new_tokens: int,
+ ) -> dict[str, Any]:
+     controller = GeneratorController(
+         mode=generator_mode,
+         deterministic=deterministic_generator,
+     )
+     schedule = ["easy"] * (episodes // 3 + (1 if episodes % 3 > 0 else 0))
+     schedule += ["medium"] * (episodes // 3 + (1 if episodes % 3 > 1 else 0))
+     schedule += ["hard"] * (episodes // 3)
+     schedule = schedule[:episodes]
+
+     tier_records: dict[str, list[float]] = {"easy": [], "medium": [], "hard": []}
+
+     for difficulty in schedule:
+         problem = controller.sample_problem(difficulty)
+         env = AdaptEnvironment(generator=controller.generator, generator_mode=generator_mode)
+         observation = env.reset(
+             difficulty=difficulty,
+             generated_problem=problem,
+             session_id=env.session_id,
+             generator_mode=generator_mode,
+         )
+
+         for _ in range(MAX_STEPS_PER_EPISODE):
+             prompt = build_solver_prompt(observation.model_dump())
+             completion = generate_completion(
+                 model=model,
+                 tokenizer=tokenizer,
+                 prompt=prompt,
+                 max_new_tokens=max_new_tokens,
+             )
+             observation = env.step(
+                 AdaptAction(
+                     session_id=env.session_id,
+                     code=extract_code(completion),
+                 )
+             )
+             if observation.done:
+                 break
+
+         controller.update(
+             problem=problem,
+             pass_rate=observation.pass_rate,
+             generator_reward_signal=observation.generator_reward_signal,
+             update_productivity=False,
+         )
+         tier_records[difficulty].append(float(observation.pass_rate))
+         logger.log_event(
+             phase=phase,
+             episode_reward=float(observation.reward),
+             pass_rate=float(observation.pass_rate),
+             visible_pass_rate=float(observation.visible_pass_rate),
+             difficulty_tier=difficulty,
+             problem_family=problem["problem_type"],
+             curriculum_level={"easy": 1, "medium": 2, "hard": 3}[difficulty],
+             execution_status=observation.execution_status,
+             attempt_number=int(observation.attempt_number),
+             family_productivity=controller.productivity_snapshot(),
+             extra={
+                 "generator_reward": round(float(observation.generator_reward_signal), 4),
+                 "problem_id": problem["problem_id"],
+             },
+         )
+
+     summary = {
+         tier: (sum(values) / len(values) if values else 0.0)
+         for tier, values in tier_records.items()
+     }
+     summary["overall"] = (
+         sum(value for values in tier_records.values() for value in values) / episodes if episodes else 0.0
+     )
+     return summary
+
+
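`run_policy_evaluation` splits the requested number of episodes across the three tiers as evenly as possible, giving any remainder to the easier tiers first. The sketch below pulls that arithmetic into a standalone helper to make the split easy to check; `build_schedule` is a hypothetical name that does not appear in the diff.

```python
def build_schedule(episodes: int) -> list[str]:
    """Distribute episodes over easy/medium/hard, remainder going to easier tiers."""
    base, remainder = divmod(episodes, 3)
    schedule = ["easy"] * (base + (1 if remainder > 0 else 0))
    schedule += ["medium"] * (base + (1 if remainder > 1 else 0))
    schedule += ["hard"] * base
    return schedule[:episodes]


schedule = build_schedule(20)  # 20 is the default for --evaluation-episodes
counts = {tier: schedule.count(tier) for tier in ("easy", "medium", "hard")}
print(counts)
```

With fewer than three episodes the hard tier receives no samples, so its entry in the summary falls back to 0.0 rather than reflecting a measured pass rate.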
+ def print_evaluation_summary(baseline: dict[str, Any], trained: dict[str, Any]) -> None:
+     print("\nBaseline vs trained pass rate summary")
+     print(f"{'Difficulty':<12} {'Baseline':>10} {'Trained':>10}")
+     print("-" * 34)
+     for tier in ("easy", "medium", "hard", "overall"):
+         print(f"{tier:<12} {baseline.get(tier, 0.0):>10.3f} {trained.get(tier, 0.0):>10.3f}")
504
 
505
 
506
  def run_training(args: argparse.Namespace) -> None:
 
512
  "Training dependencies are missing. Install `trl` and `unsloth` before running GRPO training."
513
  ) from exc
514
 
515
+ output_dir = Path(args.output_dir)
516
+ output_dir.mkdir(parents=True, exist_ok=True)
517
+
518
  PatchFastRL("GRPO", FastLanguageModel)
519
 
520
  model, tokenizer = FastLanguageModel.from_pretrained(
 
522
  max_seq_length=args.max_seq_length,
523
  load_in_4bit=not args.disable_4bit,
524
  )
525
+ if tokenizer.pad_token is None:
526
+ tokenizer.pad_token = tokenizer.eos_token
527
 
528
  model = FastLanguageModel.get_peft_model(
529
  model,
 
538
  mode="reward_aware" if args.generator_mode == "reward_aware" else "heuristic",
539
  deterministic=not args.non_deterministic_generator,
540
  )
541
+ logger = TrainingLogger(
542
+ output_dir=output_dir,
543
+ family_names=controller.family_names,
544
+ use_wandb=not args.disable_wandb,
545
+ wandb_project=args.wandb_project,
546
+ wandb_run_name=args.wandb_run_name,
547
+ )
548
+
549
+ baseline_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
550
+ trained_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
551
+
552
+ if args.baseline_eval:
553
+ FastLanguageModel.for_inference(model)
554
+ baseline_summary = run_policy_evaluation(
555
+ model=model,
556
+ tokenizer=tokenizer,
557
+ generator_mode=controller.mode,
558
+ deterministic_generator=not args.non_deterministic_generator,
559
+ episodes=args.evaluation_episodes,
560
+ logger=logger,
561
+ phase="baseline_eval",
562
+ max_new_tokens=args.eval_max_new_tokens,
563
+ )
564
+ print(f"[baseline_eval] {json.dumps(baseline_summary)}")
565
+
566
  training_args = GRPOConfig(
567
  output_dir=args.output_dir,
568
  learning_rate=args.learning_rate,
 
574
  max_steps=args.max_steps,
575
  logging_steps=1,
576
  bf16=args.bf16,
577
+ report_to=[],
578
  )
579
 
580
  trainer = GRPOTrainer(
581
  model=model,
582
+ reward_funcs=[build_reward_func(curriculum, controller, logger)],
583
  args=training_args,
584
  train_dataset=build_dataset(args.dataset_size, controller, curriculum),
585
  )
586
  trainer.train()
587
+
588
  model.save_pretrained(args.output_dir)
589
  tokenizer.save_pretrained(args.output_dir)
590
 
+     if args.baseline_eval:
+         FastLanguageModel.for_inference(model)
+         trained_summary = run_policy_evaluation(
+             model=model,
+             tokenizer=tokenizer,
+             generator_mode=controller.mode,
+             deterministic_generator=not args.non_deterministic_generator,
+             episodes=args.evaluation_episodes,
+             logger=logger,
+             phase="trained_eval",
+             max_new_tokens=args.eval_max_new_tokens,
+         )
+         print(f"[trained_eval] {json.dumps(trained_summary)}")
+         print_evaluation_summary(baseline_summary, trained_summary)
+
+     csv_path = logger.write_csv()
+     logger.close()
+     print(f"[artifacts] reward curve CSV written to {csv_path}")
+

  def build_parser() -> argparse.ArgumentParser:
      parser = argparse.ArgumentParser(description="GRPO training entrypoint for the ADAPT DSA environment.")
      parser.add_argument("--model-name", default="unsloth/Llama-3.2-3B-Instruct")
+     parser.add_argument("--output-dir", default="outputs_v3")
      parser.add_argument("--dataset-size", type=int, default=200)
      parser.add_argument("--max-steps", type=int, default=250)
      parser.add_argument("--batch-size", type=int, default=1)
      parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
      parser.add_argument("--num-generations", type=int, default=8)
      parser.add_argument("--max-seq-length", type=int, default=2048)
+     parser.add_argument("--max-prompt-length", type=int, default=1024)
      parser.add_argument("--max-completion-length", type=int, default=512)
      parser.add_argument("--learning-rate", type=float, default=5e-6)
      parser.add_argument("--lora-rank", type=int, default=16)
      parser.add_argument("--lora-alpha", type=int, default=16)
      parser.add_argument("--disable-4bit", action="store_true")
      parser.add_argument("--bf16", action="store_true")
+     parser.add_argument("--baseline-eval", action="store_true")
+     parser.add_argument("--evaluation-episodes", type=int, default=20)
+     parser.add_argument("--eval-max-new-tokens", type=int, default=512)
+     parser.add_argument("--disable-wandb", action="store_true")
+     parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
+     parser.add_argument("--wandb-run-name", default=None)
      parser.add_argument(
          "--generator-mode",
          choices=["heuristic", "reward_aware"],
+         default="reward_aware",
+         help="Use heuristic generation or reward-aware family weighting.",
      )
      parser.add_argument(
          "--non-deterministic-generator",
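The evaluation and logging flags added in this hunk can be tried in isolation. The following is a minimal sketch that assumes only the flag names and defaults visible in the diff; `build_sketch_parser` is a hypothetical helper for illustration and not part of `training/train_grpo.py`.

```python
import argparse


def build_sketch_parser() -> argparse.ArgumentParser:
    """Hypothetical parser holding only a subset of the flags shown in the diff."""
    parser = argparse.ArgumentParser(description="Subset of the GRPO training flags, for illustration.")
    parser.add_argument("--baseline-eval", action="store_true")
    parser.add_argument("--evaluation-episodes", type=int, default=20)
    parser.add_argument("--eval-max-new-tokens", type=int, default=512)
    parser.add_argument("--disable-wandb", action="store_true")
    parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
    parser.add_argument(
        "--generator-mode",
        choices=["heuristic", "reward_aware"],
        default="reward_aware",
    )
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_sketch_parser().parse_args(["--baseline-eval", "--evaluation-episodes", "5"])
print(args.baseline_eval, args.evaluation_episodes, args.generator_mode)
```

Because `--baseline-eval` is a `store_true` flag, both the baseline and the trained evaluation passes in the diff are skipped unless it is passed explicitly.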
verifier/metrics.py CHANGED
@@ -3,30 +3,54 @@ from __future__ import annotations
  from typing import Any


- def compute_pass_rate(results: list[dict[str, Any]]) -> tuple[float, dict[str, Any]]:
      total = len(results)
      passed = sum(1 for result in results if result["passed"])
      timeout_count = sum(1 for result in results if result["status"] == "timeout")
      runtime_error_count = sum(1 for result in results if result["status"] == "runtime_error")
      invalid_output_count = sum(1 for result in results if result["status"] == "invalid_output_format")
      wrong_answer_count = sum(1 for result in results if result["status"] == "wrong_answer")
      format_ok_count = sum(1 for result in results if result.get("format_ok", False))

-     pass_rate = passed / total if total else 0.0
      format_compliance = format_ok_count / total if total else 0.0
-     timeout_rate = timeout_count / total if total else 0.0
-     runtime_error_rate = runtime_error_count / total if total else 0.0
-     invalid_output_rate = invalid_output_count / total if total else 0.0
-
-     reward_components = {
-         "correctness": 0.8 * pass_rate,
-         "format": 0.1 * format_compliance,
-         "execution": 0.1 if timeout_count == 0 and runtime_error_count == 0 else 0.0,
-         "timeout_penalty": -0.2 * timeout_rate,
-         "runtime_penalty": -0.1 * runtime_error_rate,
-         "invalid_output_penalty": -0.1 * invalid_output_rate,
-     }
-     reward = max(0.0, min(1.0, sum(reward_components.values())))

      if timeout_count:
          execution_status = "timeout"
@@ -39,10 +63,23 @@ def compute_pass_rate(results: list[dict[str, Any]]) -> tuple[float, dict[str, A
      else:
          execution_status = "completed"

      return reward, {
          "passed": passed,
          "total": total,
          "pass_rate": round(pass_rate, 4),
          "timeout_count": timeout_count,
          "runtime_error_count": runtime_error_count,
          "invalid_output_count": invalid_output_count,
@@ -50,6 +87,8 @@ def compute_pass_rate(results: list[dict[str, Any]]) -> tuple[float, dict[str, A
          "format_compliance": round(format_compliance, 4),
          "execution_status": execution_status,
          "reward_components": {
-             key: round(float(value), 4) for key, value in reward_components.items()
          },
      }
 
  from typing import Any


+ def compute_reward(
+     pass_rate: float,
+     step_number: int,
+     execution_status: str,
+     format_compliance: float,
+ ) -> float:
+     """
+     Clean, interpretable reward signal for GRPO training.
+     """
+     del format_compliance
+
+     step_discount = 1.0 if step_number == 1 else (0.85 if step_number == 2 else 0.70)
+     correctness = pass_rate
+
+     if execution_status == "timeout":
+         return 0.0
+     if execution_status == "syntax_error":
+         return 0.0
+
+     reward = correctness * step_discount
+     return round(min(max(reward, 0.0), 1.0), 4)
+
+
+ def compute_pass_rate(
+     results: list[dict[str, Any]],
+     step_number: int = 1,
+ ) -> tuple[float, dict[str, Any]]:
      total = len(results)
+     hidden_results = [result for result in results if result.get("visibility") == "hidden"]
+     visible_results = [result for result in results if result.get("visibility") == "visible"]
+
+     hidden_total = len(hidden_results)
+     visible_total = len(visible_results)
+
+     hidden_passed = sum(1 for result in hidden_results if result["passed"])
+     visible_passed = sum(1 for result in visible_results if result["passed"])
      passed = sum(1 for result in results if result["passed"])
+
      timeout_count = sum(1 for result in results if result["status"] == "timeout")
      runtime_error_count = sum(1 for result in results if result["status"] == "runtime_error")
      invalid_output_count = sum(1 for result in results if result["status"] == "invalid_output_format")
      wrong_answer_count = sum(1 for result in results if result["status"] == "wrong_answer")
      format_ok_count = sum(1 for result in results if result.get("format_ok", False))

+     hidden_pass_rate = hidden_passed / hidden_total if hidden_total else 0.0
+     visible_pass_rate = visible_passed / visible_total if visible_total else 0.0
+     pass_rate = hidden_pass_rate if hidden_total else (passed / total if total else 0.0)
      format_compliance = format_ok_count / total if total else 0.0

      if timeout_count:
          execution_status = "timeout"
 
      else:
          execution_status = "completed"

+     reward = compute_reward(
+         pass_rate=pass_rate,
+         step_number=step_number,
+         execution_status=execution_status,
+         format_compliance=format_compliance,
+     )
+
      return reward, {
          "passed": passed,
          "total": total,
+         "hidden_passed": hidden_passed,
+         "hidden_total": hidden_total,
+         "visible_passed": visible_passed,
+         "visible_total": visible_total,
          "pass_rate": round(pass_rate, 4),
+         "hidden_pass_rate": round(hidden_pass_rate, 4),
+         "visible_pass_rate": round(visible_pass_rate, 4),
          "timeout_count": timeout_count,
          "runtime_error_count": runtime_error_count,
          "invalid_output_count": invalid_output_count,
 
          "format_compliance": round(format_compliance, 4),
          "execution_status": execution_status,
          "reward_components": {
+             "correctness": round(float(pass_rate), 4),
+             "step_discount": 1.0 if step_number == 1 else (0.85 if step_number == 2 else 0.70),
+             "reward": reward,
          },
      }
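
The new reward rule in `verifier/metrics.py` can be checked by hand with a small standalone sketch. It restates the logic visible in the diff (hidden-test pass rate when hidden tests exist, a per-step discount, and zero reward on `timeout` or `syntax_error`). The names `step_discount`, `sketch_reward`, and `sketch_pass_rate` and the sample `results` list are hypothetical and used only for illustration; the sketch does not import the repository module.

```python
def step_discount(step_number: int) -> float:
    """Full credit on the first attempt, 0.85 on the second, 0.70 afterwards."""
    if step_number == 1:
        return 1.0
    return 0.85 if step_number == 2 else 0.70


def sketch_reward(pass_rate: float, step_number: int, execution_status: str) -> float:
    """Zero on timeout or syntax error, else the discounted pass rate clamped to [0, 1]."""
    if execution_status in ("timeout", "syntax_error"):
        return 0.0
    reward = pass_rate * step_discount(step_number)
    return round(min(max(reward, 0.0), 1.0), 4)


def sketch_pass_rate(results: list[dict]) -> float:
    """Score on hidden tests when any exist, otherwise fall back to all tests."""
    hidden = [r for r in results if r.get("visibility") == "hidden"]
    pool = hidden if hidden else results
    return sum(1 for r in pool if r["passed"]) / len(pool) if pool else 0.0


# Hypothetical episode: both visible tests pass, one of two hidden tests passes.
results = [
    {"visibility": "visible", "passed": True},
    {"visibility": "visible", "passed": True},
    {"visibility": "hidden", "passed": True},
    {"visibility": "hidden", "passed": False},
]

rate = sketch_pass_rate(results)             # only the hidden tests count here
print(rate)
print(sketch_reward(rate, 1, "completed"))   # first attempt, no discount
print(sketch_reward(rate, 2, "completed"))   # second attempt, discounted by 0.85
print(sketch_reward(rate, 1, "timeout"))     # a timeout zeroes the reward
```

In this sketch a submission that passes every visible test still earns only half credit, because the pass rate that feeds the reward is taken from the hidden tests alone.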