Spaces:
Running
Running
Commit ·
5b695bd
1
Parent(s): 267d60a
v4
Browse files- README.md +165 -111
- client.py +13 -3
- env/adapt_env.py +301 -67
- env/generator.py +990 -131
- env/test_cases.py +1 -3
- models.py +11 -3
- openenv.yaml +3 -31
- scripts/test_env.py +42 -15
- server/app.py +71 -15
- training/plot_results.py +139 -0
- training/train_grpo.py +436 -60
- verifier/metrics.py +55 -16
README.md
CHANGED
|
@@ -8,196 +8,250 @@ tags:
|
|
| 8 |
- openenv
|
| 9 |
- reinforcement-learning
|
| 10 |
- code-generation
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# ADAPT DSA
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
```text
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
```
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
-
|
| 34 |
-
- `reset()` returns a typed observation.
|
| 35 |
-
- `step(action)` accepts an `AdaptAction` with a Python `code` string.
|
| 36 |
-
- `state` exposes episode id, step count, current problem id, difficulty, and recent metrics.
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
```
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
```python
|
| 48 |
-
|
| 49 |
-
"code": "n = int(input())\nprint(n * 2)"
|
| 50 |
-
}
|
| 51 |
```
|
| 52 |
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
- input format
|
| 59 |
-
- constraints
|
| 60 |
-
- examples
|
| 61 |
-
- visible tests
|
| 62 |
-
- problem id
|
| 63 |
-
- difficulty tier
|
| 64 |
-
- feedback
|
| 65 |
-
- pass rate, visible pass rate, and hidden pass rate
|
| 66 |
-
- syntax/runtime/timeout status
|
| 67 |
-
- reward components
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
|
| 76 |
-
- syntax validity
|
| 77 |
-
- clean execution
|
| 78 |
-
- output format compliance
|
| 79 |
-
- timeout penalty
|
| 80 |
-
- runtime error penalty
|
| 81 |
-
- static safety rejection for dangerous imports such as `os`, `subprocess`, `socket`, `pathlib`, and `shutil`
|
| 82 |
|
| 83 |
-
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
```
|
| 94 |
|
| 95 |
-
|
| 96 |
|
| 97 |
-
```
|
| 98 |
-
$env:PYTHONPATH="C:\Users\kaust\PycharmProjects\OpenEnv\src;$PWD"
|
| 99 |
-
```
|
| 100 |
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
|
| 104 |
|
| 105 |
-
```
|
| 106 |
-
|
| 107 |
```
|
| 108 |
|
| 109 |
-
|
| 110 |
|
| 111 |
-
```
|
| 112 |
-
|
| 113 |
```
|
| 114 |
|
| 115 |
-
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
|
| 122 |
|
| 123 |
-
|
| 124 |
-
- `GET /health`
|
| 125 |
-
- `GET /metadata`
|
| 126 |
-
- `GET /tasks`
|
| 127 |
-
- `GET /schema`
|
| 128 |
-
- `POST /reset`
|
| 129 |
-
- `POST /step`
|
| 130 |
-
- `GET /state`
|
| 131 |
-
- `POST /mcp`
|
| 132 |
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
```powershell
|
| 136 |
-
|
| 137 |
```
|
| 138 |
|
| 139 |
-
|
| 140 |
|
| 141 |
```powershell
|
| 142 |
-
|
| 143 |
```
|
| 144 |
|
| 145 |
-
|
| 146 |
|
| 147 |
```powershell
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
```
|
| 150 |
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
|
| 153 |
```powershell
|
| 154 |
-
|
|
|
|
|
|
|
| 155 |
```
|
| 156 |
|
| 157 |
-
|
| 158 |
|
| 159 |
```powershell
|
| 160 |
-
|
| 161 |
```
|
| 162 |
|
| 163 |
-
|
| 164 |
|
| 165 |
```powershell
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
```
|
| 171 |
|
| 172 |
-
|
| 173 |
|
| 174 |
```powershell
|
| 175 |
-
python training\
|
| 176 |
```
|
| 177 |
|
| 178 |
-
## Hugging Face
|
| 179 |
|
| 180 |
-
This repo is
|
| 181 |
|
| 182 |
```powershell
|
| 183 |
openenv push --repo-id <your-hf-username>/adapt-dsa-tutor
|
| 184 |
```
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
- live Hugging Face Space link
|
| 189 |
-
- training reward/loss plots from Disha's run
|
| 190 |
-
- before/after code example showing a problem the model failed before training and solved after training
|
| 191 |
-
- mini-blog or short video link
|
| 192 |
-
|
| 193 |
-
## Current Problem Bank
|
| 194 |
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
|
| 198 |
-
- `easy_sum_two`
|
| 199 |
-
- `medium_maximum`
|
| 200 |
-
- `medium_count_even`
|
| 201 |
-
- `hard_reverse_words`
|
| 202 |
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
- openenv
|
| 9 |
- reinforcement-learning
|
| 10 |
- code-generation
|
| 11 |
+
- llm-training
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# ADAPT: Adversarial DSA Programming Tutor
|
| 15 |
|
| 16 |
+
LLMs are getting better at one-shot code generation, but they still struggle with the thing real engineers do all day: read feedback, debug, and repair. ADAPT closes that gap by turning algorithm practice into a self-repair RL environment where the model must improve over multiple attempts instead of guessing once.
|
| 17 |
|
| 18 |
+
## Why ADAPT exists
|
| 19 |
|
| 20 |
+
Most code-generation benchmarks test whether a model can land the answer immediately. They do not test whether the model can recover from partial failure, use examples productively, or adapt as the task distribution changes.
|
| 21 |
|
| 22 |
+
ADAPT is built to stress exactly those capabilities:
|
| 23 |
+
|
| 24 |
+
- adaptive difficulty across easy, medium, and hard DSA families
|
| 25 |
+
- visible examples plus hidden evaluation tests
|
| 26 |
+
- multi-step repair with feedback between attempts
|
| 27 |
+
- reward-aware problem generation that shifts toward the most educational families
|
| 28 |
+
|
| 29 |
+
## Architecture
|
| 30 |
|
| 31 |
```text
|
| 32 |
+
+------------+ +-----------+ +----------+ +-----------+
|
| 33 |
+
| Generator | --> | Problem | --> | Solver | --> | Execution |
|
| 34 |
+
+------------+ +-----------+ +----------+ +-----------+
|
| 35 |
+
^ |
|
| 36 |
+
| v
|
| 37 |
+
+------------- Curriculum <- Reward <- Verification -----+
|
| 38 |
```
|
| 39 |
|
| 40 |
+
## What the agent sees, does, and gets rewarded for
|
| 41 |
+
|
| 42 |
+
The agent sees a plain-English programming problem, the stdin format, constraints, and two worked examples. It writes Python code that reads from stdin and prints to stdout.
|
| 43 |
+
|
| 44 |
+
The environment executes that code on 10 tests per problem:
|
| 45 |
+
|
| 46 |
+
- 2 visible tests shown as examples
|
| 47 |
+
- 8 hidden tests used for the real pass-rate reward
|
| 48 |
|
| 49 |
+
After each attempt, the environment returns:
|
| 50 |
|
| 51 |
+
- hidden pass rate
|
| 52 |
+
- visible pass rate
|
| 53 |
+
- execution status such as `completed`, `wrong_answer`, `runtime_error`, or `timeout`
|
| 54 |
+
- a compact list of which tests failed
|
| 55 |
+
- enough context to try again on the same problem
|
| 56 |
|
| 57 |
+
## Multi-step repair loop
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
Each episode allows up to 3 attempts on the same problem.
|
| 60 |
|
| 61 |
+
1. Attempt 1: the agent submits a first solution.
|
| 62 |
+
2. Feedback: ADAPT reports the current execution status, hidden pass rate, visible pass rate, and which visible/hidden tests failed.
|
| 63 |
+
3. Attempt 2 or 3: the agent repairs its code using that feedback.
|
| 64 |
+
4. The episode ends early if all hidden tests pass.
|
| 65 |
+
|
| 66 |
+
Concrete example:
|
| 67 |
+
|
| 68 |
+
```text
|
| 69 |
+
Problem family: running_total
|
| 70 |
+
|
| 71 |
+
Attempt 1 code:
|
| 72 |
+
print(sum(nums))
|
| 73 |
+
|
| 74 |
+
Feedback:
|
| 75 |
+
Attempt 1/3
|
| 76 |
+
Previous attempt status: ready
|
| 77 |
+
Current execution status: wrong_answer
|
| 78 |
+
Hidden pass rate: 0.25
|
| 79 |
+
Visible pass rate: 0.50
|
| 80 |
+
Failed tests:
|
| 81 |
+
- Visible test #2: wrong_answer (expected=5 3 10, got=10)
|
| 82 |
+
- Hidden test #1: wrong_answer
|
| 83 |
+
- Hidden test #4: wrong_answer
|
| 84 |
+
|
| 85 |
+
Attempt 2 code:
|
| 86 |
+
running, out = 0, []
|
| 87 |
+
for x in nums:
|
| 88 |
+
running += x
|
| 89 |
+
out.append(str(running))
|
| 90 |
+
print(" ".join(out))
|
| 91 |
```
|
| 92 |
|
| 93 |
+
That repair loop is the core novelty of ADAPT: the model is rewarded for debugging, not just for lucky first drafts.
|
| 94 |
+
|
| 95 |
+
## Reward function
|
| 96 |
+
|
| 97 |
+
ADAPT uses a clean reward signal driven by hidden correctness:
|
| 98 |
|
| 99 |
```python
|
| 100 |
+
reward = hidden_pass_rate * step_discount
|
|
|
|
|
|
|
| 101 |
```
|
| 102 |
|
| 103 |
+
Where:
|
| 104 |
|
| 105 |
+
- `step_discount = 1.00` on attempt 1
|
| 106 |
+
- `step_discount = 0.85` on attempt 2
|
| 107 |
+
- `step_discount = 0.70` on attempt 3
|
| 108 |
|
| 109 |
+
Additional shaping for the repair loop:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
- if a failed non-terminal attempt improves hidden pass rate, reward = `0.1 * delta_pass_rate`
|
| 112 |
+
- if the final attempt still fails, reward = `0.0`
|
| 113 |
+
- timeouts and syntax errors always get `0.0`
|
| 114 |
|
| 115 |
+
Examples:
|
| 116 |
|
| 117 |
+
- attempt 1 solves all 8 hidden tests: reward = `1.0`
|
| 118 |
+
- attempt 2 solves all 8 hidden tests: reward = `0.85`
|
| 119 |
+
- attempt 2 raises the hidden pass rate from `0.25` (attempt 1) to `0.50` but still fails, with one attempt remaining: reward = `0.1 * 0.25 = 0.025`
|
| 120 |
+
- attempt 3 still fails: reward = `0.0`
|
| 121 |
|
| 122 |
+
## Problem families
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
ADAPT now covers 20 algorithmic families instead of a tiny fixed bank:
|
| 125 |
|
| 126 |
+
- Easy: `sum_even_numbers`, `range_span`, `count_vowels`, `max_consecutive_ones`, `fizzbuzz_variant`, `running_total`
|
| 127 |
+
- Medium: `count_local_peaks`, `longest_non_decreasing_run`, `two_sum_count`, `max_subarray_sum`, `group_anagrams_count`, `balanced_brackets`, `matrix_diagonal_sum`
|
| 128 |
+
- Hard: `smallest_most_frequent`, `reverse_words`, `longest_common_subsequence`, `word_ladder_steps`, `merge_intervals`, `min_coins`, `rotate_matrix_90`
|
| 129 |
|
| 130 |
+
Every family has:
|
| 131 |
|
| 132 |
+
- its own randomized case generator
|
| 133 |
+
- 2 visible example tests
|
| 134 |
+
- 8 hidden evaluation tests
|
| 135 |
+
- a reference solver that auto-generates expected outputs
|
|
|
|
| 136 |
|
| 137 |
+
## Self-improving curriculum
|
| 138 |
|
| 139 |
+
ADAPT uses one curriculum authority in training: the `CurriculumManager` inside `training/train_grpo.py`.
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
- promote threshold: `0.70`
|
| 142 |
+
- demote threshold: `0.30`
|
| 143 |
+
- moving-average window: `10` episodes
|
| 144 |
|
| 145 |
+
On top of that, the generator tracks `family_productivity`, an EMA of how educational each family is:
|
| 146 |
|
| 147 |
+
```text
|
| 148 |
+
family_productivity[family] = 0.9 * old + 0.1 * generator_reward
|
| 149 |
```
|
| 150 |
|
| 151 |
+
Families that produce pass rates near the learning sweet spot, around `0.5`, become more likely to be sampled via a softmax distribution. This creates a closed loop:
|
| 152 |
|
| 153 |
+
```text
|
| 154 |
+
productive families -> more samples -> better learning signal -> updated family productivity
|
| 155 |
```
|
| 156 |
|
| 157 |
+
That makes ADAPT more than a static benchmark. The environment actively searches for the problems that teach the model the most.
|
| 158 |
|
| 159 |
+
## Results
|
| 160 |
+
|
| 161 |
+
[INSERT: reward curve plot]
|
| 162 |
+
|
| 163 |
+
[INSERT: baseline vs trained table]
|
| 164 |
+
|
| 165 |
+
Recommended artifacts to include here:
|
| 166 |
+
|
| 167 |
+
- reward curve from `training/reward_curve.csv`
|
| 168 |
+
- `reward_curve.png`
|
| 169 |
+
- `pass_rate_by_difficulty.png`
|
| 170 |
+
- `family_productivity.png`
|
| 171 |
+
- one before/after repair example from baseline vs trained evaluation
|
| 172 |
|
| 173 |
+
## How to run
|
| 174 |
|
| 175 |
+
### 1. Install dependencies
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
```powershell
|
| 178 |
+
cd <path-to-your-clone>\meta-rl-dsa-solver
|
| 179 |
+
python -m venv .venv
|
| 180 |
+
.\.venv\Scripts\pip install -e .
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
For training and plotting, also install the training extras:
|
| 184 |
|
| 185 |
```powershell
|
| 186 |
+
.\.venv\Scripts\pip install trl unsloth matplotlib wandb
|
| 187 |
```
|
| 188 |
|
| 189 |
+
### 2. Start the OpenEnv server
|
| 190 |
|
| 191 |
```powershell
|
| 192 |
+
python server\app.py
|
| 193 |
```
|
| 194 |
|
| 195 |
+
### 3. Reset an environment session
|
| 196 |
|
| 197 |
```bat
|
| 198 |
+
curl -X POST http://localhost:7860/reset ^
|
| 199 |
+
-H "Content-Type: application/json" ^
|
| 200 |
+
-d "{\"difficulty\":\"easy\"}"
|
| 201 |
```
|
| 202 |
|
| 203 |
+
The response includes a `session_id`. Reuse it for `step` and `state`.
|
| 204 |
+
|
| 205 |
+
### 4. Submit code to `/step`
|
| 206 |
|
| 207 |
```bat
|
| 208 |
+
curl -X POST http://localhost:7860/step ^
|
| 209 |
+
-H "Content-Type: application/json" ^
|
| 210 |
+
-d "{\"session_id\":\"<SESSION_ID>\",\"code\":\"n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))\"}"
|
| 211 |
```
|
| 212 |
|
| 213 |
+
### 5. Inspect current state
|
| 214 |
|
| 215 |
```powershell
|
| 216 |
+
curl "http://localhost:7860/state?session_id=<SESSION_ID>"
|
| 217 |
```
|
| 218 |
|
| 219 |
+
### 6. Run training
|
| 220 |
|
| 221 |
```bat
|
| 222 |
+
python training\train_grpo.py ^
|
| 223 |
+
--generator-mode reward_aware ^
|
| 224 |
+
--baseline-eval ^
|
| 225 |
+
--output-dir outputs_v3
|
| 226 |
```
|
| 227 |
|
| 228 |
+
### 7. Plot the training curves
|
| 229 |
|
| 230 |
```powershell
|
| 231 |
+
python training\plot_results.py outputs_v3\reward_curve.csv
|
| 232 |
```
|
| 233 |
|
| 234 |
+
## Hugging Face Space
|
| 235 |
|
| 236 |
+
This repo is designed to be hosted as an OpenEnv FastAPI Space.
|
| 237 |
|
| 238 |
```powershell
|
| 239 |
openenv push --repo-id <your-hf-username>/adapt-dsa-tutor
|
| 240 |
```
|
| 241 |
|
| 242 |
+
## Submission checklist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
+
- OpenEnv environment with `Environment`, `reset`, `step`, and `state`
|
| 245 |
+
- valid `openenv.yaml`
|
| 246 |
+
- Hugging Face Space deployment
|
| 247 |
+
- GRPO training script with Unsloth + TRL
|
| 248 |
+
- reward and pass-rate plots from a real run
|
| 249 |
+
- baseline vs trained evaluation summary
|
| 250 |
+
- Colab notebook link for reproducibility
|
| 251 |
|
| 252 |
+
## Links
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
+
- HuggingFace Space URL: [HuggingFace Space URL]
|
| 255 |
+
- Colab Training Notebook: [Colab Training Notebook]
|
| 256 |
+
- HF Blog Post: [HF Blog Post]
|
| 257 |
+
- YouTube Demo: [YouTube Demo]
|
client.py
CHANGED
|
@@ -11,6 +11,7 @@ class AdaptEnvClient:
|
|
| 11 |
def __init__(self, base_url: str = "http://localhost:7860") -> None:
|
| 12 |
self.base_url = base_url.rstrip("/")
|
| 13 |
self._client = httpx.Client(base_url=self.base_url, timeout=30.0)
|
|
|
|
| 14 |
|
| 15 |
def close(self) -> None:
|
| 16 |
self._client.close()
|
|
@@ -18,15 +19,24 @@ class AdaptEnvClient:
|
|
| 18 |
def reset(self, **params: Any) -> dict[str, Any]:
|
| 19 |
response = self._client.post("/reset", json=params)
|
| 20 |
response.raise_for_status()
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def step(self, code: str) -> dict[str, Any]:
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
response.raise_for_status()
|
| 26 |
return response.json()
|
| 27 |
|
| 28 |
def state(self) -> dict[str, Any]:
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
response.raise_for_status()
|
| 31 |
return response.json()
|
| 32 |
|
|
|
|
| 11 |
def __init__(self, base_url: str = "http://localhost:7860") -> None:
|
| 12 |
self.base_url = base_url.rstrip("/")
|
| 13 |
self._client = httpx.Client(base_url=self.base_url, timeout=30.0)
|
| 14 |
+
self.session_id: str | None = None
|
| 15 |
|
| 16 |
def close(self) -> None:
|
| 17 |
self._client.close()
|
|
|
|
| 19 |
def reset(self, **params: Any) -> dict[str, Any]:
|
| 20 |
response = self._client.post("/reset", json=params)
|
| 21 |
response.raise_for_status()
|
| 22 |
+
payload = response.json()
|
| 23 |
+
self.session_id = payload.get("session_id")
|
| 24 |
+
return payload
|
| 25 |
|
| 26 |
def step(self, code: str) -> dict[str, Any]:
|
| 27 |
+
if not self.session_id:
|
| 28 |
+
raise RuntimeError("Call reset() before step() so the client has a session_id.")
|
| 29 |
+
response = self._client.post(
|
| 30 |
+
"/step",
|
| 31 |
+
json=AdaptAction(session_id=self.session_id, code=code).model_dump(),
|
| 32 |
+
)
|
| 33 |
response.raise_for_status()
|
| 34 |
return response.json()
|
| 35 |
|
| 36 |
def state(self) -> dict[str, Any]:
|
| 37 |
+
if not self.session_id:
|
| 38 |
+
raise RuntimeError("Call reset() before state() so the client has a session_id.")
|
| 39 |
+
response = self._client.get("/state", params={"session_id": self.session_id})
|
| 40 |
response.raise_for_status()
|
| 41 |
return response.json()
|
| 42 |
|
env/adapt_env.py
CHANGED
|
@@ -6,6 +6,7 @@ from uuid import uuid4
|
|
| 6 |
|
| 7 |
from env.generator import DIFFICULTY_LABELS, GeneratorAgent, generator_reward, validate_problem
|
| 8 |
from models import AdaptAction, AdaptObservation, AdaptState
|
|
|
|
| 9 |
|
| 10 |
try:
|
| 11 |
from openenv.core.env_server.interfaces import Environment
|
|
@@ -22,6 +23,7 @@ except ImportError:
|
|
| 22 |
|
| 23 |
|
| 24 |
FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
@@ -31,14 +33,16 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 31 |
self,
|
| 32 |
generator: GeneratorAgent | None = None,
|
| 33 |
generator_mode: str = "heuristic",
|
|
|
|
| 34 |
) -> None:
|
| 35 |
super().__init__()
|
| 36 |
self.generator = generator or GeneratorAgent()
|
| 37 |
self.generator_mode = generator_mode
|
|
|
|
| 38 |
self.problem: dict[str, Any] = {}
|
| 39 |
-
self.test_cases: list[dict[str,
|
| 40 |
self.last_results: list[dict[str, Any]] = []
|
| 41 |
-
self.max_history =
|
| 42 |
self.min_difficulty = 1
|
| 43 |
self.max_difficulty = 3
|
| 44 |
self.difficulty: int = 1
|
|
@@ -49,7 +53,17 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 49 |
"problem_signatures": [],
|
| 50 |
"episode_index": 0,
|
| 51 |
}
|
| 52 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def reset(
|
| 55 |
self,
|
|
@@ -59,35 +73,52 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 59 |
difficulty: str | None = None,
|
| 60 |
generated_problem: dict[str, Any] | None = None,
|
| 61 |
generator_mode: str | None = None,
|
|
|
|
|
|
|
| 62 |
**_: Any,
|
| 63 |
) -> AdaptObservation:
|
| 64 |
del seed
|
|
|
|
|
|
|
|
|
|
| 65 |
if generator_mode is not None:
|
| 66 |
self.generator_mode = generator_mode
|
| 67 |
if difficulty is not None:
|
| 68 |
self.difficulty = self._difficulty_to_tier(difficulty)
|
| 69 |
-
elif
|
| 70 |
-
|
|
|
|
|
|
|
| 71 |
|
| 72 |
self.problem = self._load_problem(
|
| 73 |
generated_problem=generated_problem,
|
| 74 |
problem_id=problem_id,
|
|
|
|
| 75 |
)
|
| 76 |
self.test_cases = [dict(test_case) for test_case in self.problem["test_cases"]]
|
| 77 |
self.last_results = []
|
|
|
|
|
|
|
|
|
|
| 78 |
self._state = AdaptState(
|
|
|
|
| 79 |
episode_id=episode_id or str(uuid4()),
|
| 80 |
step_count=0,
|
| 81 |
problem_id=self.problem["problem_id"],
|
| 82 |
problem_type=self.problem.get("problem_type", ""),
|
| 83 |
difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
|
| 84 |
generator_mode=self.generator_mode,
|
|
|
|
| 85 |
generated_problem=self._public_problem_view(),
|
|
|
|
| 86 |
)
|
| 87 |
return self._build_observation(
|
| 88 |
reward=0.0,
|
| 89 |
done=False,
|
| 90 |
-
feedback=
|
|
|
|
|
|
|
|
|
|
| 91 |
execution_status="ready",
|
| 92 |
)
|
| 93 |
|
|
@@ -98,59 +129,130 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 98 |
**_: Any,
|
| 99 |
) -> AdaptObservation:
|
| 100 |
del timeout_s
|
|
|
|
| 101 |
if not self.problem:
|
| 102 |
-
self.reset()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
self._state.step_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
syntax_ok, syntax_error = self._check_syntax(action.code)
|
| 106 |
if not syntax_ok:
|
|
|
|
| 107 |
observation = self._build_observation(
|
| 108 |
reward=0.0,
|
| 109 |
-
done=
|
| 110 |
-
feedback=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
syntax_valid=False,
|
| 112 |
execution_status="syntax_error",
|
| 113 |
-
reward_components={
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
)
|
| 115 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
return observation
|
| 117 |
|
| 118 |
safety_ok, safety_error = self._check_safety(action.code)
|
| 119 |
if not safety_ok:
|
|
|
|
| 120 |
observation = self._build_observation(
|
| 121 |
reward=0.0,
|
| 122 |
-
done=
|
| 123 |
-
feedback=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
syntax_valid=True,
|
| 125 |
execution_status="safety_violation",
|
| 126 |
-
reward_components={
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
)
|
| 128 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
return observation
|
| 130 |
|
| 131 |
-
|
| 132 |
self.last_results = list(metadata.get("results", []))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
observation = self._build_observation(
|
| 134 |
reward=reward,
|
| 135 |
-
done=
|
| 136 |
-
feedback=
|
| 137 |
-
pass_rate=
|
| 138 |
-
visible_pass_rate=
|
| 139 |
-
hidden_pass_rate=
|
| 140 |
syntax_valid=True,
|
| 141 |
-
execution_status=
|
| 142 |
timeout_count=int(metadata.get("timeout_count", 0)),
|
| 143 |
runtime_error_count=int(metadata.get("runtime_error_count", 0)),
|
| 144 |
invalid_output_count=int(metadata.get("invalid_output_count", 0)),
|
| 145 |
wrong_answer_count=int(metadata.get("wrong_answer_count", 0)),
|
| 146 |
format_compliance=float(metadata.get("format_compliance", 0.0)),
|
| 147 |
-
reward_components=
|
| 148 |
-
key: round(float(value), 4)
|
| 149 |
-
for key, value in dict(metadata.get("reward_components", {})).items()
|
| 150 |
-
},
|
| 151 |
generator_reward_signal=float(metadata.get("generator_reward", 0.0)),
|
| 152 |
)
|
| 153 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 154 |
return observation
|
| 155 |
|
| 156 |
@property
|
|
@@ -177,23 +279,26 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 177 |
) -> AdaptObservation:
|
| 178 |
public_problem = self._public_problem_view()
|
| 179 |
return AdaptObservation(
|
|
|
|
| 180 |
problem_id=self.problem.get("problem_id", ""),
|
| 181 |
problem_type=self.problem.get("problem_type", ""),
|
| 182 |
difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
|
|
|
|
|
|
|
| 183 |
problem=public_problem.get("problem", ""),
|
| 184 |
input_format=public_problem.get("input_format", ""),
|
| 185 |
constraints=public_problem.get("constraints", ""),
|
| 186 |
feedback=feedback,
|
| 187 |
-
pass_rate=pass_rate,
|
| 188 |
-
visible_pass_rate=visible_pass_rate,
|
| 189 |
-
hidden_pass_rate=hidden_pass_rate,
|
| 190 |
syntax_valid=syntax_valid,
|
| 191 |
execution_status=execution_status,
|
| 192 |
timeout_count=timeout_count,
|
| 193 |
runtime_error_count=runtime_error_count,
|
| 194 |
invalid_output_count=invalid_output_count,
|
| 195 |
wrong_answer_count=wrong_answer_count,
|
| 196 |
-
format_compliance=format_compliance,
|
| 197 |
reward_components=reward_components or {},
|
| 198 |
generator_reward_signal=round(float(generator_reward_signal), 4),
|
| 199 |
reward=round(max(0.0, min(1.0, reward)), 4),
|
|
@@ -204,15 +309,22 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 204 |
self,
|
| 205 |
generated_problem: dict[str, Any] | None,
|
| 206 |
problem_id: str | None,
|
|
|
|
| 207 |
) -> dict[str, Any]:
|
| 208 |
-
candidate = generated_problem or self.generator.
|
| 209 |
self.difficulty,
|
| 210 |
self.history,
|
| 211 |
problem_id=problem_id,
|
|
|
|
| 212 |
)
|
| 213 |
if validate_problem(candidate):
|
| 214 |
return candidate
|
| 215 |
-
fallback = self.generator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
if not validate_problem(fallback):
|
| 217 |
raise ValueError("Generator produced an invalid problem twice in a row.")
|
| 218 |
return fallback
|
|
@@ -221,53 +333,155 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 221 |
try:
|
| 222 |
from verifier.verifier import verify
|
| 223 |
except ImportError as exc:
|
| 224 |
-
return 0.0, {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
try:
|
| 227 |
reward, metadata = verify(code, self.test_cases)
|
| 228 |
except Exception as exc:
|
| 229 |
-
return 0.0, {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
metadata = dict(metadata or {})
|
| 232 |
diversity_bonus = self._diversity_bonus(self.problem.get("problem_type", ""))
|
| 233 |
validity_bonus = float(self.problem.get("validity_bonus", 0.0))
|
|
|
|
| 234 |
metadata["generator_reward"] = generator_reward(
|
| 235 |
-
|
| 236 |
diversity_bonus=diversity_bonus,
|
| 237 |
validity_bonus=validity_bonus,
|
| 238 |
)
|
| 239 |
return float(reward), metadata
|
| 240 |
|
| 241 |
-
def
|
| 242 |
-
self
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
-
def
|
| 246 |
-
self
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
def _record_metrics(self, observation: AdaptObservation) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
self._state.last_reward = float(observation.reward or 0.0)
|
| 259 |
self._state.last_pass_rate = observation.pass_rate
|
| 260 |
self._state.last_feedback = observation.feedback
|
|
|
|
| 261 |
self._state.generator_reward_signal = observation.generator_reward_signal
|
| 262 |
-
self._state.history = {
|
| 263 |
-
|
| 264 |
-
"problem_types": list(self.history["problem_types"]),
|
| 265 |
-
"generator_rewards": list(self.history["generator_rewards"]),
|
| 266 |
-
}
|
| 267 |
self._state.recent_metrics = {
|
| 268 |
"difficulty_tier": self.difficulty,
|
| 269 |
"difficulty_label": self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
|
| 270 |
-
"
|
| 271 |
"pass_rate": observation.pass_rate,
|
| 272 |
"execution_status": observation.execution_status,
|
| 273 |
"timeout_count": observation.timeout_count,
|
|
@@ -278,27 +492,47 @@ class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
| 278 |
"reward_components": dict(observation.reward_components),
|
| 279 |
}
|
| 280 |
|
| 281 |
-
def
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
def _public_problem_view(self) -> dict[str, str]:
|
| 293 |
visible = dict(self.problem.get("visible_problem", {}))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
return {
|
| 295 |
-
"problem":
|
| 296 |
"input_format": visible.get("input_format", self.problem.get("input_format", "")),
|
| 297 |
"constraints": visible.get("constraints", self.problem.get("constraints", "")),
|
| 298 |
}
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
def _diversity_bonus(self, problem_type: str) -> float:
|
| 301 |
-
recent_types = list(self.history.get("problem_types", [])[-
|
| 302 |
if not recent_types:
|
| 303 |
return 0.1
|
| 304 |
if problem_type in recent_types:
|
|
|
|
| 6 |
|
| 7 |
from env.generator import DIFFICULTY_LABELS, GeneratorAgent, generator_reward, validate_problem
|
| 8 |
from models import AdaptAction, AdaptObservation, AdaptState
|
| 9 |
+
from verifier.metrics import compute_reward
|
| 10 |
|
| 11 |
try:
|
| 12 |
from openenv.core.env_server.interfaces import Environment
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
|
| 26 |
+
MAX_STEPS_PER_EPISODE = 3
|
| 27 |
|
| 28 |
|
| 29 |
class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
|
|
|
|
| 33 |
self,
|
| 34 |
generator: GeneratorAgent | None = None,
|
| 35 |
generator_mode: str = "heuristic",
|
| 36 |
+
session_id: str | None = None,
|
| 37 |
) -> None:
|
| 38 |
super().__init__()
|
| 39 |
self.generator = generator or GeneratorAgent()
|
| 40 |
self.generator_mode = generator_mode
|
| 41 |
+
self.session_id = session_id or str(uuid4())
|
| 42 |
self.problem: dict[str, Any] = {}
|
| 43 |
+
self.test_cases: list[dict[str, Any]] = []
|
| 44 |
self.last_results: list[dict[str, Any]] = []
|
| 45 |
+
self.max_history = 50
|
| 46 |
self.min_difficulty = 1
|
| 47 |
self.max_difficulty = 3
|
| 48 |
self.difficulty: int = 1
|
|
|
|
| 53 |
"problem_signatures": [],
|
| 54 |
"episode_index": 0,
|
| 55 |
}
|
| 56 |
+
self.attempt_history: list[dict[str, Any]] = []
|
| 57 |
+
self.previous_execution_status = "ready"
|
| 58 |
+
self.episode_done = False
|
| 59 |
+
self._state = AdaptState(
|
| 60 |
+
session_id=self.session_id,
|
| 61 |
+
episode_id=str(uuid4()),
|
| 62 |
+
step_count=0,
|
| 63 |
+
generator_mode=self.generator_mode,
|
| 64 |
+
max_steps=MAX_STEPS_PER_EPISODE,
|
| 65 |
+
history={"attempts": []},
|
| 66 |
+
)
|
| 67 |
|
| 68 |
def reset(
|
| 69 |
self,
|
|
|
|
| 73 |
difficulty: str | None = None,
|
| 74 |
generated_problem: dict[str, Any] | None = None,
|
| 75 |
generator_mode: str | None = None,
|
| 76 |
+
session_id: str | None = None,
|
| 77 |
+
family_weights: dict[str, float] | None = None,
|
| 78 |
**_: Any,
|
| 79 |
) -> AdaptObservation:
|
| 80 |
del seed
|
| 81 |
+
|
| 82 |
+
if session_id:
|
| 83 |
+
self.session_id = session_id
|
| 84 |
if generator_mode is not None:
|
| 85 |
self.generator_mode = generator_mode
|
| 86 |
if difficulty is not None:
|
| 87 |
self.difficulty = self._difficulty_to_tier(difficulty)
|
| 88 |
+
elif generated_problem is not None:
|
| 89 |
+
generated_label = str(generated_problem.get("difficulty_label", "")).strip().lower()
|
| 90 |
+
if generated_label:
|
| 91 |
+
self.difficulty = self._difficulty_to_tier(generated_label)
|
| 92 |
|
| 93 |
self.problem = self._load_problem(
|
| 94 |
generated_problem=generated_problem,
|
| 95 |
problem_id=problem_id,
|
| 96 |
+
family_weights=family_weights,
|
| 97 |
)
|
| 98 |
self.test_cases = [dict(test_case) for test_case in self.problem["test_cases"]]
|
| 99 |
self.last_results = []
|
| 100 |
+
self.attempt_history = []
|
| 101 |
+
self.previous_execution_status = "ready"
|
| 102 |
+
self.episode_done = False
|
| 103 |
self._state = AdaptState(
|
| 104 |
+
session_id=self.session_id,
|
| 105 |
episode_id=episode_id or str(uuid4()),
|
| 106 |
step_count=0,
|
| 107 |
problem_id=self.problem["problem_id"],
|
| 108 |
problem_type=self.problem.get("problem_type", ""),
|
| 109 |
difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
|
| 110 |
generator_mode=self.generator_mode,
|
| 111 |
+
max_steps=MAX_STEPS_PER_EPISODE,
|
| 112 |
generated_problem=self._public_problem_view(),
|
| 113 |
+
history={"attempts": []},
|
| 114 |
)
|
| 115 |
return self._build_observation(
|
| 116 |
reward=0.0,
|
| 117 |
done=False,
|
| 118 |
+
feedback=(
|
| 119 |
+
"You have up to 3 attempts. Submit Python code that reads stdin and prints the required answer. "
|
| 120 |
+
"Use the examples to infer the expected behavior."
|
| 121 |
+
),
|
| 122 |
execution_status="ready",
|
| 123 |
)
|
| 124 |
|
|
|
|
| 129 |
**_: Any,
|
| 130 |
) -> AdaptObservation:
|
| 131 |
del timeout_s
|
| 132 |
+
|
| 133 |
if not self.problem:
|
| 134 |
+
self.reset(session_id=action.session_id or self.session_id)
|
| 135 |
+
|
| 136 |
+
if self.episode_done:
|
| 137 |
+
return self._build_observation(
|
| 138 |
+
reward=float(self._state.last_reward or 0.0),
|
| 139 |
+
done=True,
|
| 140 |
+
feedback="This episode is finished. Call reset() to start a new problem.",
|
| 141 |
+
pass_rate=float(self._state.last_pass_rate or 0.0),
|
| 142 |
+
visible_pass_rate=float(self._state.recent_metrics.get("visible_pass_rate", 0.0)),
|
| 143 |
+
hidden_pass_rate=float(self._state.last_pass_rate or 0.0),
|
| 144 |
+
syntax_valid=self._state.last_execution_status != "syntax_error",
|
| 145 |
+
execution_status=self._state.last_execution_status or "completed",
|
| 146 |
+
timeout_count=int(self._state.recent_metrics.get("timeout_count", 0)),
|
| 147 |
+
runtime_error_count=int(self._state.recent_metrics.get("runtime_error_count", 0)),
|
| 148 |
+
invalid_output_count=int(self._state.recent_metrics.get("invalid_output_count", 0)),
|
| 149 |
+
wrong_answer_count=int(self._state.recent_metrics.get("wrong_answer_count", 0)),
|
| 150 |
+
format_compliance=float(self._state.recent_metrics.get("format_compliance", 0.0)),
|
| 151 |
+
reward_components=dict(self._state.recent_metrics.get("reward_components", {})),
|
| 152 |
+
generator_reward_signal=float(self._state.generator_reward_signal or 0.0),
|
| 153 |
+
)
|
| 154 |
|
| 155 |
self._state.step_count += 1
|
| 156 |
+
attempt_number = self._state.step_count
|
| 157 |
+
previous_status = self.previous_execution_status
|
| 158 |
+
previous_pass_rate = float(self._state.last_pass_rate or 0.0)
|
| 159 |
+
|
| 160 |
syntax_ok, syntax_error = self._check_syntax(action.code)
|
| 161 |
if not syntax_ok:
|
| 162 |
+
done = attempt_number >= MAX_STEPS_PER_EPISODE
|
| 163 |
observation = self._build_observation(
|
| 164 |
reward=0.0,
|
| 165 |
+
done=done,
|
| 166 |
+
feedback=self._format_static_feedback(
|
| 167 |
+
attempt_number=attempt_number,
|
| 168 |
+
previous_status=previous_status,
|
| 169 |
+
execution_status="syntax_error",
|
| 170 |
+
details=f"Syntax error: {syntax_error}",
|
| 171 |
+
),
|
| 172 |
syntax_valid=False,
|
| 173 |
execution_status="syntax_error",
|
| 174 |
+
reward_components={
|
| 175 |
+
"correctness": 0.0,
|
| 176 |
+
"step_discount": 1.0 if attempt_number == 1 else (0.85 if attempt_number == 2 else 0.70),
|
| 177 |
+
"progress_delta": 0.0,
|
| 178 |
+
},
|
| 179 |
)
|
| 180 |
+
self.last_results = []
|
| 181 |
+
self.previous_execution_status = observation.execution_status
|
| 182 |
+
self._record_metrics(observation)
|
| 183 |
+
if done:
|
| 184 |
+
self._finalize_episode(observation)
|
| 185 |
return observation
|
| 186 |
|
| 187 |
safety_ok, safety_error = self._check_safety(action.code)
|
| 188 |
if not safety_ok:
|
| 189 |
+
done = attempt_number >= MAX_STEPS_PER_EPISODE
|
| 190 |
observation = self._build_observation(
|
| 191 |
reward=0.0,
|
| 192 |
+
done=done,
|
| 193 |
+
feedback=self._format_static_feedback(
|
| 194 |
+
attempt_number=attempt_number,
|
| 195 |
+
previous_status=previous_status,
|
| 196 |
+
execution_status="safety_violation",
|
| 197 |
+
details=safety_error,
|
| 198 |
+
),
|
| 199 |
syntax_valid=True,
|
| 200 |
execution_status="safety_violation",
|
| 201 |
+
reward_components={
|
| 202 |
+
"correctness": 0.0,
|
| 203 |
+
"step_discount": 1.0 if attempt_number == 1 else (0.85 if attempt_number == 2 else 0.70),
|
| 204 |
+
"progress_delta": 0.0,
|
| 205 |
+
},
|
| 206 |
)
|
| 207 |
+
self.last_results = []
|
| 208 |
+
self.previous_execution_status = observation.execution_status
|
| 209 |
+
self._record_metrics(observation)
|
| 210 |
+
if done:
|
| 211 |
+
self._finalize_episode(observation)
|
| 212 |
return observation
|
| 213 |
|
| 214 |
+
_, metadata = self._verify_submission(action.code)
|
| 215 |
self.last_results = list(metadata.get("results", []))
|
| 216 |
+
hidden_pass_rate = float(metadata.get("hidden_pass_rate", metadata.get("pass_rate", 0.0)))
|
| 217 |
+
visible_pass_rate = float(metadata.get("visible_pass_rate", 0.0))
|
| 218 |
+
execution_status = str(metadata.get("execution_status", "completed"))
|
| 219 |
+
done = hidden_pass_rate == 1.0 or attempt_number >= MAX_STEPS_PER_EPISODE
|
| 220 |
+
reward, reward_components = self._shape_reward(
|
| 221 |
+
pass_rate=hidden_pass_rate,
|
| 222 |
+
step_number=attempt_number,
|
| 223 |
+
execution_status=execution_status,
|
| 224 |
+
previous_pass_rate=previous_pass_rate,
|
| 225 |
+
done=done,
|
| 226 |
+
)
|
| 227 |
+
feedback = self._format_feedback(
|
| 228 |
+
results=self.last_results,
|
| 229 |
+
attempt_number=attempt_number,
|
| 230 |
+
previous_status=previous_status,
|
| 231 |
+
execution_status=execution_status,
|
| 232 |
+
hidden_pass_rate=hidden_pass_rate,
|
| 233 |
+
visible_pass_rate=visible_pass_rate,
|
| 234 |
+
)
|
| 235 |
observation = self._build_observation(
|
| 236 |
reward=reward,
|
| 237 |
+
done=done,
|
| 238 |
+
feedback=feedback,
|
| 239 |
+
pass_rate=hidden_pass_rate,
|
| 240 |
+
visible_pass_rate=visible_pass_rate,
|
| 241 |
+
hidden_pass_rate=hidden_pass_rate,
|
| 242 |
syntax_valid=True,
|
| 243 |
+
execution_status=execution_status,
|
| 244 |
timeout_count=int(metadata.get("timeout_count", 0)),
|
| 245 |
runtime_error_count=int(metadata.get("runtime_error_count", 0)),
|
| 246 |
invalid_output_count=int(metadata.get("invalid_output_count", 0)),
|
| 247 |
wrong_answer_count=int(metadata.get("wrong_answer_count", 0)),
|
| 248 |
format_compliance=float(metadata.get("format_compliance", 0.0)),
|
| 249 |
+
reward_components=reward_components,
|
|
|
|
|
|
|
|
|
|
| 250 |
generator_reward_signal=float(metadata.get("generator_reward", 0.0)),
|
| 251 |
)
|
| 252 |
+
self.previous_execution_status = observation.execution_status
|
| 253 |
+
self._record_metrics(observation)
|
| 254 |
+
if done:
|
| 255 |
+
self._finalize_episode(observation)
|
| 256 |
return observation
|
| 257 |
|
| 258 |
@property
|
|
|
|
| 279 |
) -> AdaptObservation:
|
| 280 |
public_problem = self._public_problem_view()
|
| 281 |
return AdaptObservation(
|
| 282 |
+
session_id=self.session_id,
|
| 283 |
problem_id=self.problem.get("problem_id", ""),
|
| 284 |
problem_type=self.problem.get("problem_type", ""),
|
| 285 |
difficulty=self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
|
| 286 |
+
attempt_number=self._state.step_count,
|
| 287 |
+
max_steps=MAX_STEPS_PER_EPISODE,
|
| 288 |
problem=public_problem.get("problem", ""),
|
| 289 |
input_format=public_problem.get("input_format", ""),
|
| 290 |
constraints=public_problem.get("constraints", ""),
|
| 291 |
feedback=feedback,
|
| 292 |
+
pass_rate=round(float(pass_rate), 4),
|
| 293 |
+
visible_pass_rate=round(float(visible_pass_rate), 4),
|
| 294 |
+
hidden_pass_rate=round(float(hidden_pass_rate), 4),
|
| 295 |
syntax_valid=syntax_valid,
|
| 296 |
execution_status=execution_status,
|
| 297 |
timeout_count=timeout_count,
|
| 298 |
runtime_error_count=runtime_error_count,
|
| 299 |
invalid_output_count=invalid_output_count,
|
| 300 |
wrong_answer_count=wrong_answer_count,
|
| 301 |
+
format_compliance=round(float(format_compliance), 4),
|
| 302 |
reward_components=reward_components or {},
|
| 303 |
generator_reward_signal=round(float(generator_reward_signal), 4),
|
| 304 |
reward=round(max(0.0, min(1.0, reward)), 4),
|
|
|
|
| 309 |
self,
|
| 310 |
generated_problem: dict[str, Any] | None,
|
| 311 |
problem_id: str | None,
|
| 312 |
+
family_weights: dict[str, float] | None,
|
| 313 |
) -> dict[str, Any]:
|
| 314 |
+
candidate = generated_problem or self.generator.generate_problem(
|
| 315 |
self.difficulty,
|
| 316 |
self.history,
|
| 317 |
problem_id=problem_id,
|
| 318 |
+
family_weights=family_weights,
|
| 319 |
)
|
| 320 |
if validate_problem(candidate):
|
| 321 |
return candidate
|
| 322 |
+
fallback = self.generator.generate_problem(
|
| 323 |
+
self.difficulty,
|
| 324 |
+
self.history,
|
| 325 |
+
problem_id=problem_id,
|
| 326 |
+
family_weights=family_weights,
|
| 327 |
+
)
|
| 328 |
if not validate_problem(fallback):
|
| 329 |
raise ValueError("Generator produced an invalid problem twice in a row.")
|
| 330 |
return fallback
|
|
|
|
| 333 |
try:
|
| 334 |
from verifier.verifier import verify
|
| 335 |
except ImportError as exc:
|
| 336 |
+
return 0.0, {
|
| 337 |
+
"feedback": f"Verifier unavailable: {exc}",
|
| 338 |
+
"execution_status": "verifier_error",
|
| 339 |
+
"results": [],
|
| 340 |
+
}
|
| 341 |
|
| 342 |
try:
|
| 343 |
reward, metadata = verify(code, self.test_cases)
|
| 344 |
except Exception as exc:
|
| 345 |
+
return 0.0, {
|
| 346 |
+
"feedback": f"Verifier crashed: {exc}",
|
| 347 |
+
"execution_status": "verifier_error",
|
| 348 |
+
"results": [],
|
| 349 |
+
}
|
| 350 |
|
| 351 |
metadata = dict(metadata or {})
|
| 352 |
diversity_bonus = self._diversity_bonus(self.problem.get("problem_type", ""))
|
| 353 |
validity_bonus = float(self.problem.get("validity_bonus", 0.0))
|
| 354 |
+
hidden_pass_rate = float(metadata.get("hidden_pass_rate", metadata.get("pass_rate", 0.0)))
|
| 355 |
metadata["generator_reward"] = generator_reward(
|
| 356 |
+
hidden_pass_rate,
|
| 357 |
diversity_bonus=diversity_bonus,
|
| 358 |
validity_bonus=validity_bonus,
|
| 359 |
)
|
| 360 |
return float(reward), metadata
|
| 361 |
|
| 362 |
+
def _shape_reward(
|
| 363 |
+
self,
|
| 364 |
+
pass_rate: float,
|
| 365 |
+
step_number: int,
|
| 366 |
+
execution_status: str,
|
| 367 |
+
previous_pass_rate: float,
|
| 368 |
+
done: bool,
|
| 369 |
+
) -> tuple[float, dict[str, float]]:
|
| 370 |
+
step_discount = 1.0 if step_number == 1 else (0.85 if step_number == 2 else 0.70)
|
| 371 |
+
progress_delta = max(0.0, float(pass_rate) - float(previous_pass_rate))
|
| 372 |
+
|
| 373 |
+
if execution_status in {"timeout", "syntax_error", "safety_violation"}:
|
| 374 |
+
reward = 0.0
|
| 375 |
+
elif pass_rate == 1.0:
|
| 376 |
+
reward = compute_reward(
|
| 377 |
+
pass_rate=pass_rate,
|
| 378 |
+
step_number=step_number,
|
| 379 |
+
execution_status=execution_status,
|
| 380 |
+
format_compliance=0.0,
|
| 381 |
+
)
|
| 382 |
+
elif done:
|
| 383 |
+
reward = 0.0
|
| 384 |
+
else:
|
| 385 |
+
reward = round(0.1 * progress_delta, 4)
|
| 386 |
+
|
| 387 |
+
return reward, {
|
| 388 |
+
"correctness": round(float(pass_rate), 4),
|
| 389 |
+
"step_discount": round(step_discount, 4),
|
| 390 |
+
"progress_delta": round(progress_delta, 4),
|
| 391 |
+
"reward": round(float(reward), 4),
|
| 392 |
+
}
|
| 393 |
|
| 394 |
+
def _format_feedback(
    self,
    results: list[dict[str, Any]],
    attempt_number: int,
    previous_status: str,
    execution_status: str,
    hidden_pass_rate: float,
    visible_pass_rate: float,
) -> str:
    """Build the multi-line feedback text shown after an executed attempt.

    The text starts with the attempt counter, the previous and current
    execution status, and both pass rates. It then lists the failed tests
    (as produced by ``_summarize_failed_tests``) or, when there are none,
    a one-line closing remark.

    Args:
        results: Per-test result dictionaries for the attempt.
        attempt_number: 1-based attempt number within the episode.
        previous_status: Execution status of the preceding attempt.
        execution_status: Execution status of the current attempt.
        hidden_pass_rate: Pass rate on hidden tests.
        visible_pass_rate: Pass rate on visible tests.

    Returns:
        The feedback lines joined with newlines.
    """
    report = [
        f"Attempt {attempt_number}/{MAX_STEPS_PER_EPISODE}.",
        f"Previous attempt status: {previous_status}.",
        f"Current execution status: {execution_status}.",
        f"Hidden pass rate: {hidden_pass_rate:.2f}. Visible pass rate: {visible_pass_rate:.2f}.",
    ]

    failure_lines = self._summarize_failed_tests(results)
    if failure_lines:
        report += ["Failed tests:", *failure_lines]
    else:
        # Nothing failed (or no details were reported): pick the closing line.
        if hidden_pass_rate == 1.0:
            closing = "All hidden tests passed."
        else:
            closing = "No failing test details were available."
        report.append(closing)

    return "\n".join(report)
|
| 420 |
+
|
| 421 |
+
def _format_static_feedback(
    self,
    attempt_number: int,
    previous_status: str,
    execution_status: str,
    details: str,
) -> str:
    """Build feedback for an attempt rejected before execution.

    Args:
        attempt_number: 1-based attempt number within the episode.
        previous_status: Execution status of the preceding attempt.
        execution_status: Status assigned to the current attempt.
        details: Free-form explanation appended as the last line.

    Returns:
        Four lines joined with newlines: attempt counter, previous status,
        current status, and ``details``.
    """
    counter_line = f"Attempt {attempt_number}/{MAX_STEPS_PER_EPISODE}."
    previous_line = f"Previous attempt status: {previous_status}."
    current_line = f"Current execution status: {execution_status}."
    return "\n".join((counter_line, previous_line, current_line, details))
|
| 436 |
|
| 437 |
+
def _summarize_failed_tests(self, results: list[dict[str, Any]]) -> list[str]:
|
| 438 |
+
summaries: list[str] = []
|
| 439 |
+
for result in results:
|
| 440 |
+
if result.get("passed", False):
|
| 441 |
+
continue
|
| 442 |
+
visibility = str(result.get("visibility", "hidden"))
|
| 443 |
+
label = f"{visibility.title()} test #{int(result.get('index', 0)) + 1}"
|
| 444 |
+
status = str(result.get("status", "unknown"))
|
| 445 |
+
if visibility == "visible":
|
| 446 |
+
actual = str(result.get("stdout", "")).strip()
|
| 447 |
+
expected = str(result.get("expected", "")).strip()
|
| 448 |
+
details = []
|
| 449 |
+
if expected:
|
| 450 |
+
details.append(f"expected={expected}")
|
| 451 |
+
if actual:
|
| 452 |
+
details.append(f"got={actual}")
|
| 453 |
+
if result.get("stderr"):
|
| 454 |
+
details.append("stderr_present")
|
| 455 |
+
if details:
|
| 456 |
+
summaries.append(f"- {label}: {status} ({', '.join(details)})")
|
| 457 |
+
else:
|
| 458 |
+
summaries.append(f"- {label}: {status}")
|
| 459 |
+
else:
|
| 460 |
+
summaries.append(f"- {label}: {status}")
|
| 461 |
+
return summaries
|
| 462 |
|
| 463 |
def _record_metrics(self, observation: AdaptObservation) -> None:
|
| 464 |
+
attempt_record = {
|
| 465 |
+
"attempt_number": observation.attempt_number,
|
| 466 |
+
"reward": float(observation.reward or 0.0),
|
| 467 |
+
"pass_rate": float(observation.pass_rate),
|
| 468 |
+
"visible_pass_rate": float(observation.visible_pass_rate),
|
| 469 |
+
"execution_status": observation.execution_status,
|
| 470 |
+
"feedback": observation.feedback,
|
| 471 |
+
"done": bool(observation.done),
|
| 472 |
+
}
|
| 473 |
+
self.attempt_history.append(attempt_record)
|
| 474 |
self._state.last_reward = float(observation.reward or 0.0)
|
| 475 |
self._state.last_pass_rate = observation.pass_rate
|
| 476 |
self._state.last_feedback = observation.feedback
|
| 477 |
+
self._state.last_execution_status = observation.execution_status
|
| 478 |
self._state.generator_reward_signal = observation.generator_reward_signal
|
| 479 |
+
self._state.history = {"attempts": list(self.attempt_history)}
|
| 480 |
+
self._state.generated_problem = self._public_problem_view()
|
|
|
|
|
|
|
|
|
|
| 481 |
self._state.recent_metrics = {
|
| 482 |
"difficulty_tier": self.difficulty,
|
| 483 |
"difficulty_label": self.problem.get("difficulty_label", self._tier_to_difficulty(self.difficulty)),
|
| 484 |
+
"visible_pass_rate": observation.visible_pass_rate,
|
| 485 |
"pass_rate": observation.pass_rate,
|
| 486 |
"execution_status": observation.execution_status,
|
| 487 |
"timeout_count": observation.timeout_count,
|
|
|
|
| 492 |
"reward_components": dict(observation.reward_components),
|
| 493 |
}
|
| 494 |
|
| 495 |
+
def _finalize_episode(self, observation: AdaptObservation) -> None:
|
| 496 |
+
self.episode_done = True
|
| 497 |
+
self._update_history(observation.pass_rate, observation.generator_reward_signal)
|
| 498 |
+
|
| 499 |
+
def _update_history(self, pass_rate: float, generator_signal: float) -> None:
|
| 500 |
+
self.history["recent_pass_rates"].append(round(float(pass_rate), 4))
|
| 501 |
+
self.history["problem_types"].append(self.problem.get("problem_type", ""))
|
| 502 |
+
self.history["problem_signatures"].append(self.problem.get("problem_id", ""))
|
| 503 |
+
self.history["generator_rewards"].append(round(float(generator_signal), 4))
|
| 504 |
+
self.history["episode_index"] = int(self.history.get("episode_index", 0)) + 1
|
| 505 |
+
|
| 506 |
+
for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
|
| 507 |
+
values = self.history[key]
|
| 508 |
+
if len(values) > self.max_history:
|
| 509 |
+
del values[:-self.max_history]
|
| 510 |
|
| 511 |
def _public_problem_view(self) -> dict[str, str]:
|
| 512 |
visible = dict(self.problem.get("visible_problem", {}))
|
| 513 |
+
base_problem = visible.get("problem", self.problem.get("problem", ""))
|
| 514 |
+
examples = self._format_examples()
|
| 515 |
+
if examples:
|
| 516 |
+
base_problem = f"{base_problem}\n\nExamples:\n{examples}"
|
| 517 |
return {
|
| 518 |
+
"problem": base_problem,
|
| 519 |
"input_format": visible.get("input_format", self.problem.get("input_format", "")),
|
| 520 |
"constraints": visible.get("constraints", self.problem.get("constraints", "")),
|
| 521 |
}
|
| 522 |
|
| 523 |
+
def _format_examples(self) -> str:
|
| 524 |
+
visible_cases = [test_case for test_case in self.test_cases if test_case.get("is_visible", False)]
|
| 525 |
+
if not visible_cases:
|
| 526 |
+
return ""
|
| 527 |
+
chunks = []
|
| 528 |
+
for test_case in visible_cases:
|
| 529 |
+
chunks.append(
|
| 530 |
+
f"Input:\n{test_case['input']}Expected Output:\n{test_case['output']}\n"
|
| 531 |
+
)
|
| 532 |
+
return "\n".join(chunks).rstrip()
|
| 533 |
+
|
| 534 |
def _diversity_bonus(self, problem_type: str) -> float:
|
| 535 |
+
recent_types = list(self.history.get("problem_types", [])[-6:])
|
| 536 |
if not recent_types:
|
| 537 |
return 0.1
|
| 538 |
if problem_type in recent_types:
|
env/generator.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import hashlib
|
| 4 |
-
import math
|
| 5 |
import random
|
|
|
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from typing import Any, Callable
|
| 8 |
|
| 9 |
-
VISIBLE_TEST_COUNT =
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@dataclass(frozen=True)
|
|
@@ -17,9 +19,9 @@ class ProblemTemplate:
|
|
| 17 |
title: str
|
| 18 |
input_format: str
|
| 19 |
constraints: str
|
| 20 |
-
statement_builder: Callable[[
|
| 21 |
solver: Callable[[str], str]
|
| 22 |
-
case_builder: Callable[[random.Random
|
| 23 |
|
| 24 |
|
| 25 |
def generator_reward(
|
|
@@ -54,14 +56,15 @@ def validate_problem(problem_dict: dict[str, Any]) -> bool:
|
|
| 54 |
return False
|
| 55 |
|
| 56 |
test_cases = problem_dict.get("test_cases")
|
| 57 |
-
if not isinstance(test_cases, list) or len(test_cases)
|
| 58 |
return False
|
| 59 |
|
| 60 |
seen_inputs: set[str] = set()
|
| 61 |
distinct_outputs: set[str] = set()
|
| 62 |
visible_count = 0
|
|
|
|
| 63 |
|
| 64 |
-
for test_case in test_cases:
|
| 65 |
if not isinstance(test_case, dict):
|
| 66 |
return False
|
| 67 |
|
|
@@ -76,12 +79,26 @@ def validate_problem(problem_dict: dict[str, Any]) -> bool:
|
|
| 76 |
return False
|
| 77 |
seen_inputs.add(raw_input)
|
| 78 |
distinct_outputs.add(raw_output.strip())
|
| 79 |
-
visible_count += 1 if is_visible else 0
|
| 80 |
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
return False
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
return False
|
| 86 |
|
| 87 |
return True
|
|
@@ -94,46 +111,46 @@ class GeneratorAgent:
|
|
| 94 |
self.deterministic = deterministic
|
| 95 |
self.templates = _build_templates()
|
| 96 |
|
| 97 |
-
def
|
| 98 |
self,
|
| 99 |
difficulty_level: int | float | str,
|
| 100 |
history: dict[str, Any] | None,
|
| 101 |
problem_id: str | None = None,
|
|
|
|
| 102 |
) -> dict[str, Any]:
|
| 103 |
history = history or {}
|
| 104 |
target_tier = _difficulty_to_tier(difficulty_level)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
raw_cases = template.case_builder(rng, 0.2 + adjusted_tier * 0.25)
|
| 117 |
test_cases = [
|
| 118 |
{
|
| 119 |
"input": case_input,
|
| 120 |
"output": template.solver(case_input),
|
| 121 |
-
"is_visible":
|
| 122 |
}
|
| 123 |
-
for case_input in raw_cases
|
| 124 |
]
|
| 125 |
signature = self._problem_signature(template.problem_type, test_cases)
|
| 126 |
problem = {
|
| 127 |
"problem_id": f"{template.problem_type}_{signature[:8]}",
|
| 128 |
"problem_type": template.problem_type,
|
| 129 |
-
"difficulty": round(self._tier_to_scalar(
|
| 130 |
-
"difficulty_label": DIFFICULTY_LABELS[
|
| 131 |
-
"problem": template.statement_builder(
|
| 132 |
"input_format": template.input_format,
|
| 133 |
"constraints": template.constraints,
|
| 134 |
"test_cases": test_cases,
|
| 135 |
"visible_problem": {
|
| 136 |
-
"problem": template.statement_builder(
|
| 137 |
"input_format": template.input_format,
|
| 138 |
"constraints": template.constraints,
|
| 139 |
},
|
|
@@ -145,17 +162,19 @@ class GeneratorAgent:
|
|
| 145 |
|
| 146 |
raise ValueError(f"Unable to generate a valid problem for template {template.problem_type}")
|
| 147 |
|
| 148 |
-
def
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
|
| 160 |
def _choose_template(
|
| 161 |
self,
|
|
@@ -163,6 +182,7 @@ class GeneratorAgent:
|
|
| 163 |
history: dict[str, Any],
|
| 164 |
rng: random.Random,
|
| 165 |
forced_problem_type: str | None = None,
|
|
|
|
| 166 |
) -> ProblemTemplate:
|
| 167 |
eligible = [template for template in self.templates if template.difficulty_tier == tier]
|
| 168 |
if not eligible:
|
|
@@ -176,27 +196,37 @@ class GeneratorAgent:
|
|
| 176 |
if template.problem_type == forced_problem_type:
|
| 177 |
return template
|
| 178 |
|
| 179 |
-
recent_types = list(history.get("problem_types", [])[-
|
| 180 |
-
|
| 181 |
for template in eligible:
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
|
| 188 |
def _rng_for(
|
| 189 |
self,
|
| 190 |
tier: int,
|
| 191 |
history: dict[str, Any],
|
| 192 |
problem_id: str | None,
|
|
|
|
| 193 |
) -> random.Random:
|
|
|
|
|
|
|
|
|
|
| 194 |
seed_material = {
|
| 195 |
"tier": tier,
|
| 196 |
"problem_id": problem_id or "",
|
| 197 |
"pass_rates": [round(float(value), 4) for value in history.get("recent_pass_rates", [])[-8:]],
|
| 198 |
"problem_types": list(history.get("problem_types", [])[-8:]),
|
| 199 |
"episode_index": int(history.get("episode_index", 0)),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
}
|
| 201 |
digest = hashlib.sha256(repr(seed_material).encode("utf-8")).hexdigest()
|
| 202 |
return random.Random(int(digest[:16], 16))
|
|
@@ -246,7 +276,7 @@ def _build_templates() -> list[ProblemTemplate]:
|
|
| 246 |
title="Sum Even Numbers",
|
| 247 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 248 |
constraints="1 <= n <= 12; -100 <= values[i] <= 100",
|
| 249 |
-
statement_builder=lambda
|
| 250 |
"Given a list of integers, print the sum of the numbers that are even. "
|
| 251 |
"If no number is even, print 0."
|
| 252 |
),
|
|
@@ -259,19 +289,71 @@ def _build_templates() -> list[ProblemTemplate]:
|
|
| 259 |
title="Range Span",
|
| 260 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 261 |
constraints="2 <= n <= 12; -100 <= values[i] <= 100",
|
| 262 |
-
statement_builder=lambda
|
| 263 |
"Given a list of integers, print the difference between the maximum and minimum value."
|
| 264 |
),
|
| 265 |
solver=_solve_range_span,
|
| 266 |
case_builder=_build_range_span_cases,
|
| 267 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
ProblemTemplate(
|
| 269 |
problem_type="count_local_peaks",
|
| 270 |
difficulty_tier=2,
|
| 271 |
title="Count Local Peaks",
|
| 272 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 273 |
-
constraints="3 <= n <=
|
| 274 |
-
statement_builder=lambda
|
| 275 |
"Count how many indices i are local peaks, meaning values[i] is strictly greater than both "
|
| 276 |
"values[i-1] and values[i+1]. The first and last element can never be peaks."
|
| 277 |
),
|
|
@@ -283,20 +365,81 @@ def _build_templates() -> list[ProblemTemplate]:
|
|
| 283 |
difficulty_tier=2,
|
| 284 |
title="Longest Non-Decreasing Run",
|
| 285 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 286 |
-
constraints="1 <= n <=
|
| 287 |
-
statement_builder=lambda
|
| 288 |
"Find the length of the longest contiguous subarray whose values are non-decreasing."
|
| 289 |
),
|
| 290 |
solver=_solve_longest_non_decreasing_run,
|
| 291 |
case_builder=_build_run_cases,
|
| 292 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
ProblemTemplate(
|
| 294 |
problem_type="smallest_most_frequent",
|
| 295 |
difficulty_tier=3,
|
| 296 |
title="Smallest Most Frequent",
|
| 297 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 298 |
-
constraints="1 <= n <=
|
| 299 |
-
statement_builder=lambda
|
| 300 |
"Print the value that appears most often in the array. If several values have the same highest "
|
| 301 |
"frequency, print the smallest of them."
|
| 302 |
),
|
|
@@ -308,79 +451,343 @@ def _build_templates() -> list[ProblemTemplate]:
|
|
| 308 |
difficulty_tier=3,
|
| 309 |
title="Reverse Words",
|
| 310 |
input_format="A single line containing one or more words separated by spaces.",
|
| 311 |
-
constraints="1 <= line length <=
|
| 312 |
-
statement_builder=lambda
|
| 313 |
"Read a line of text and print the words in reverse order. Multiple spaces in the input should "
|
| 314 |
"be treated as a single separator."
|
| 315 |
),
|
| 316 |
solver=_solve_reverse_words,
|
| 317 |
case_builder=_build_reverse_word_cases,
|
| 318 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
|
| 322 |
-
def
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
vocabulary = [
|
| 385 |
"graph",
|
| 386 |
"queue",
|
|
@@ -395,21 +802,170 @@ def _build_reverse_word_cases(rng: random.Random, difficulty_scalar: float) -> l
|
|
| 395 |
"node",
|
| 396 |
"edge",
|
| 397 |
]
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
|
| 409 |
def _array_case(numbers: list[int]) -> str:
|
| 410 |
return f"{len(numbers)}\n{' '.join(str(number) for number in numbers)}\n"
|
| 411 |
|
| 412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
def _solve_sum_even_numbers(stdin: str) -> str:
    """Reference solver: return the sum of the even integers in the input.

    The input is parsed with ``_parse_int_array``; the declared count is
    ignored. The result is ``"0"`` when no value is even.
    """
    _count, values = _parse_int_array(stdin)
    even_total = 0
    for value in values:
        if value % 2 == 0:
            even_total += value
    return str(even_total)
|
|
@@ -420,6 +976,47 @@ def _solve_range_span(stdin: str) -> str:
|
|
| 420 |
return str(max(numbers) - min(numbers))
|
| 421 |
|
| 422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
def _solve_count_local_peaks(stdin: str) -> str:
|
| 424 |
_, numbers = _parse_int_array(stdin)
|
| 425 |
peaks = 0
|
|
@@ -442,11 +1039,60 @@ def _solve_longest_non_decreasing_run(stdin: str) -> str:
|
|
| 442 |
return str(best)
|
| 443 |
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
def _solve_smallest_most_frequent(stdin: str) -> str:
|
| 446 |
_, numbers = _parse_int_array(stdin)
|
| 447 |
-
counts
|
| 448 |
-
for number in numbers:
|
| 449 |
-
counts[number] = counts.get(number, 0) + 1
|
| 450 |
best_count = max(counts.values())
|
| 451 |
best_value = min(number for number, count in counts.items() if count == best_count)
|
| 452 |
return str(best_value)
|
|
@@ -457,6 +1103,76 @@ def _solve_reverse_words(stdin: str) -> str:
|
|
| 457 |
return " ".join(reversed(words))
|
| 458 |
|
| 459 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
def _parse_int_array(stdin: str) -> tuple[int, list[int]]:
|
| 461 |
lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
|
| 462 |
n = int(lines[0])
|
|
@@ -464,3 +1180,146 @@ def _parse_int_array(stdin: str) -> tuple[int, list[int]]:
|
|
| 464 |
if len(numbers) != n:
|
| 465 |
raise ValueError(f"Expected {n} integers, received {len(numbers)}")
|
| 466 |
return n, numbers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import hashlib
|
|
|
|
| 4 |
import random
|
| 5 |
+
from collections import Counter, deque
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from typing import Any, Callable
|
| 8 |
|
| 9 |
+
VISIBLE_TEST_COUNT = 2
|
| 10 |
+
HIDDEN_TEST_COUNT = 8
|
| 11 |
+
TOTAL_TEST_CASES = VISIBLE_TEST_COUNT + HIDDEN_TEST_COUNT
|
| 12 |
+
MIN_TEST_CASES = TOTAL_TEST_CASES
|
| 13 |
|
| 14 |
|
| 15 |
@dataclass(frozen=True)
|
|
|
|
| 19 |
title: str
|
| 20 |
input_format: str
|
| 21 |
constraints: str
|
| 22 |
+
statement_builder: Callable[[], str]
|
| 23 |
solver: Callable[[str], str]
|
| 24 |
+
case_builder: Callable[[random.Random], list[str]]
|
| 25 |
|
| 26 |
|
| 27 |
def generator_reward(
|
|
|
|
| 56 |
return False
|
| 57 |
|
| 58 |
test_cases = problem_dict.get("test_cases")
|
| 59 |
+
if not isinstance(test_cases, list) or len(test_cases) != TOTAL_TEST_CASES:
|
| 60 |
return False
|
| 61 |
|
| 62 |
seen_inputs: set[str] = set()
|
| 63 |
distinct_outputs: set[str] = set()
|
| 64 |
visible_count = 0
|
| 65 |
+
hidden_count = 0
|
| 66 |
|
| 67 |
+
for index, test_case in enumerate(test_cases):
|
| 68 |
if not isinstance(test_case, dict):
|
| 69 |
return False
|
| 70 |
|
|
|
|
| 79 |
return False
|
| 80 |
seen_inputs.add(raw_input)
|
| 81 |
distinct_outputs.add(raw_output.strip())
|
|
|
|
| 82 |
|
| 83 |
+
if index < VISIBLE_TEST_COUNT and not is_visible:
|
| 84 |
+
return False
|
| 85 |
+
if index >= VISIBLE_TEST_COUNT and is_visible:
|
| 86 |
+
return False
|
| 87 |
+
|
| 88 |
+
if is_visible:
|
| 89 |
+
visible_count += 1
|
| 90 |
+
else:
|
| 91 |
+
hidden_count += 1
|
| 92 |
+
|
| 93 |
+
if visible_count != VISIBLE_TEST_COUNT or hidden_count != HIDDEN_TEST_COUNT:
|
| 94 |
return False
|
| 95 |
|
| 96 |
+
normalized_outputs = {output.strip().lower() for output in distinct_outputs}
|
| 97 |
+
min_output_diversity = 2 if normalized_outputs.issubset({"yes", "no", "true", "false", "0", "1"}) else max(
|
| 98 |
+
3,
|
| 99 |
+
len(test_cases) // 3,
|
| 100 |
+
)
|
| 101 |
+
if len(distinct_outputs) < min_output_diversity:
|
| 102 |
return False
|
| 103 |
|
| 104 |
return True
|
|
|
|
| 111 |
self.deterministic = deterministic
|
| 112 |
self.templates = _build_templates()
|
| 113 |
|
| 114 |
+
def generate_problem(
|
| 115 |
self,
|
| 116 |
difficulty_level: int | float | str,
|
| 117 |
history: dict[str, Any] | None,
|
| 118 |
problem_id: str | None = None,
|
| 119 |
+
family_weights: dict[str, float] | None = None,
|
| 120 |
) -> dict[str, Any]:
|
| 121 |
history = history or {}
|
| 122 |
target_tier = _difficulty_to_tier(difficulty_level)
|
| 123 |
+
rng = self._rng_for(target_tier, history, problem_id, family_weights or {})
|
| 124 |
+
template = self._choose_template(
|
| 125 |
+
target_tier,
|
| 126 |
+
history,
|
| 127 |
+
rng,
|
| 128 |
+
forced_problem_type=problem_id,
|
| 129 |
+
family_weights=family_weights or {},
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
for _ in range(20):
|
| 133 |
+
raw_cases = template.case_builder(rng)
|
|
|
|
| 134 |
test_cases = [
|
| 135 |
{
|
| 136 |
"input": case_input,
|
| 137 |
"output": template.solver(case_input),
|
| 138 |
+
"is_visible": index < VISIBLE_TEST_COUNT,
|
| 139 |
}
|
| 140 |
+
for index, case_input in enumerate(raw_cases)
|
| 141 |
]
|
| 142 |
signature = self._problem_signature(template.problem_type, test_cases)
|
| 143 |
problem = {
|
| 144 |
"problem_id": f"{template.problem_type}_{signature[:8]}",
|
| 145 |
"problem_type": template.problem_type,
|
| 146 |
+
"difficulty": round(self._tier_to_scalar(target_tier), 4),
|
| 147 |
+
"difficulty_label": DIFFICULTY_LABELS[target_tier],
|
| 148 |
+
"problem": template.statement_builder(),
|
| 149 |
"input_format": template.input_format,
|
| 150 |
"constraints": template.constraints,
|
| 151 |
"test_cases": test_cases,
|
| 152 |
"visible_problem": {
|
| 153 |
+
"problem": template.statement_builder(),
|
| 154 |
"input_format": template.input_format,
|
| 155 |
"constraints": template.constraints,
|
| 156 |
},
|
|
|
|
| 162 |
|
| 163 |
raise ValueError(f"Unable to generate a valid problem for template {template.problem_type}")
|
| 164 |
|
| 165 |
+
def generate(
|
| 166 |
+
self,
|
| 167 |
+
difficulty_level: int | float | str,
|
| 168 |
+
history: dict[str, Any] | None,
|
| 169 |
+
problem_id: str | None = None,
|
| 170 |
+
family_weights: dict[str, float] | None = None,
|
| 171 |
+
) -> dict[str, Any]:
|
| 172 |
+
return self.generate_problem(
|
| 173 |
+
difficulty_level=difficulty_level,
|
| 174 |
+
history=history,
|
| 175 |
+
problem_id=problem_id,
|
| 176 |
+
family_weights=family_weights,
|
| 177 |
+
)
|
| 178 |
|
| 179 |
def _choose_template(
|
| 180 |
self,
|
|
|
|
| 182 |
history: dict[str, Any],
|
| 183 |
rng: random.Random,
|
| 184 |
forced_problem_type: str | None = None,
|
| 185 |
+
family_weights: dict[str, float] | None = None,
|
| 186 |
) -> ProblemTemplate:
|
| 187 |
eligible = [template for template in self.templates if template.difficulty_tier == tier]
|
| 188 |
if not eligible:
|
|
|
|
| 196 |
if template.problem_type == forced_problem_type:
|
| 197 |
return template
|
| 198 |
|
| 199 |
+
recent_types = list(history.get("problem_types", [])[-6:])
|
| 200 |
+
weights: list[float] = []
|
| 201 |
for template in eligible:
|
| 202 |
+
base_weight = float((family_weights or {}).get(template.problem_type, 1.0))
|
| 203 |
+
base_weight = max(base_weight, 1e-6)
|
| 204 |
+
if template.problem_type in recent_types:
|
| 205 |
+
base_weight *= 0.35
|
| 206 |
+
weights.append(base_weight)
|
| 207 |
+
|
| 208 |
+
return rng.choices(eligible, weights=weights, k=1)[0]
|
| 209 |
|
| 210 |
def _rng_for(
|
| 211 |
self,
|
| 212 |
tier: int,
|
| 213 |
history: dict[str, Any],
|
| 214 |
problem_id: str | None,
|
| 215 |
+
family_weights: dict[str, float],
|
| 216 |
) -> random.Random:
|
| 217 |
+
if not self.deterministic:
|
| 218 |
+
return random.Random()
|
| 219 |
+
|
| 220 |
seed_material = {
|
| 221 |
"tier": tier,
|
| 222 |
"problem_id": problem_id or "",
|
| 223 |
"pass_rates": [round(float(value), 4) for value in history.get("recent_pass_rates", [])[-8:]],
|
| 224 |
"problem_types": list(history.get("problem_types", [])[-8:]),
|
| 225 |
"episode_index": int(history.get("episode_index", 0)),
|
| 226 |
+
"family_weights": {
|
| 227 |
+
key: round(float(value), 4)
|
| 228 |
+
for key, value in sorted(family_weights.items())
|
| 229 |
+
},
|
| 230 |
}
|
| 231 |
digest = hashlib.sha256(repr(seed_material).encode("utf-8")).hexdigest()
|
| 232 |
return random.Random(int(digest[:16], 16))
|
|
|
|
| 276 |
title="Sum Even Numbers",
|
| 277 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 278 |
constraints="1 <= n <= 12; -100 <= values[i] <= 100",
|
| 279 |
+
statement_builder=lambda: (
|
| 280 |
"Given a list of integers, print the sum of the numbers that are even. "
|
| 281 |
"If no number is even, print 0."
|
| 282 |
),
|
|
|
|
| 289 |
title="Range Span",
|
| 290 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 291 |
constraints="2 <= n <= 12; -100 <= values[i] <= 100",
|
| 292 |
+
statement_builder=lambda: (
|
| 293 |
"Given a list of integers, print the difference between the maximum and minimum value."
|
| 294 |
),
|
| 295 |
solver=_solve_range_span,
|
| 296 |
case_builder=_build_range_span_cases,
|
| 297 |
),
|
| 298 |
+
ProblemTemplate(
|
| 299 |
+
problem_type="count_vowels",
|
| 300 |
+
difficulty_tier=1,
|
| 301 |
+
title="Count Vowels",
|
| 302 |
+
input_format="A single line containing lowercase or uppercase letters and spaces.",
|
| 303 |
+
constraints="1 <= line length <= 80",
|
| 304 |
+
statement_builder=lambda: (
|
| 305 |
+
"Count how many vowels appear in the input line. Treat a, e, i, o, u as vowels "
|
| 306 |
+
"and ignore case."
|
| 307 |
+
),
|
| 308 |
+
solver=_solve_count_vowels,
|
| 309 |
+
case_builder=_build_count_vowels_cases,
|
| 310 |
+
),
|
| 311 |
+
ProblemTemplate(
|
| 312 |
+
problem_type="max_consecutive_ones",
|
| 313 |
+
difficulty_tier=1,
|
| 314 |
+
title="Max Consecutive Ones",
|
| 315 |
+
input_format="A single line containing a binary string.",
|
| 316 |
+
constraints="1 <= string length <= 40",
|
| 317 |
+
statement_builder=lambda: (
|
| 318 |
+
"Print the length of the longest contiguous block of '1' characters in the binary string."
|
| 319 |
+
),
|
| 320 |
+
solver=_solve_max_consecutive_ones,
|
| 321 |
+
case_builder=_build_max_consecutive_ones_cases,
|
| 322 |
+
),
|
| 323 |
+
ProblemTemplate(
|
| 324 |
+
problem_type="fizzbuzz_variant",
|
| 325 |
+
difficulty_tier=1,
|
| 326 |
+
title="FizzBuzz Variant",
|
| 327 |
+
input_format="The first line contains n a b. The second line contains label_a and label_b.",
|
| 328 |
+
constraints="1 <= n <= 25; 2 <= a, b <= 9; labels contain only letters",
|
| 329 |
+
statement_builder=lambda: (
|
| 330 |
+
"For each integer from 1 to n, print label_a if the number is divisible by a, "
|
| 331 |
+
"label_b if it is divisible by b, and the concatenation label_a+label_b if it is divisible "
|
| 332 |
+
"by both. Otherwise print the number itself. Output all tokens on one line separated by spaces."
|
| 333 |
+
),
|
| 334 |
+
solver=_solve_fizzbuzz_variant,
|
| 335 |
+
case_builder=_build_fizzbuzz_variant_cases,
|
| 336 |
+
),
|
| 337 |
+
ProblemTemplate(
|
| 338 |
+
problem_type="running_total",
|
| 339 |
+
difficulty_tier=1,
|
| 340 |
+
title="Running Total",
|
| 341 |
+
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 342 |
+
constraints="1 <= n <= 14; -50 <= values[i] <= 50",
|
| 343 |
+
statement_builder=lambda: (
|
| 344 |
+
"Print the running total after each element of the array. Output the cumulative sums on one line "
|
| 345 |
+
"separated by spaces."
|
| 346 |
+
),
|
| 347 |
+
solver=_solve_running_total,
|
| 348 |
+
case_builder=_build_running_total_cases,
|
| 349 |
+
),
|
| 350 |
ProblemTemplate(
|
| 351 |
problem_type="count_local_peaks",
|
| 352 |
difficulty_tier=2,
|
| 353 |
title="Count Local Peaks",
|
| 354 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 355 |
+
constraints="3 <= n <= 16; -100 <= values[i] <= 100",
|
| 356 |
+
statement_builder=lambda: (
|
| 357 |
"Count how many indices i are local peaks, meaning values[i] is strictly greater than both "
|
| 358 |
"values[i-1] and values[i+1]. The first and last element can never be peaks."
|
| 359 |
),
|
|
|
|
| 365 |
difficulty_tier=2,
|
| 366 |
title="Longest Non-Decreasing Run",
|
| 367 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 368 |
+
constraints="1 <= n <= 18; -100 <= values[i] <= 100",
|
| 369 |
+
statement_builder=lambda: (
|
| 370 |
"Find the length of the longest contiguous subarray whose values are non-decreasing."
|
| 371 |
),
|
| 372 |
solver=_solve_longest_non_decreasing_run,
|
| 373 |
case_builder=_build_run_cases,
|
| 374 |
),
|
| 375 |
+
ProblemTemplate(
|
| 376 |
+
problem_type="two_sum_count",
|
| 377 |
+
difficulty_tier=2,
|
| 378 |
+
title="Two Sum Count",
|
| 379 |
+
input_format="The first line contains n and target. The second line contains n space-separated integers.",
|
| 380 |
+
constraints="2 <= n <= 16; -50 <= values[i] <= 50",
|
| 381 |
+
statement_builder=lambda: (
|
| 382 |
+
"Count how many index pairs (i, j) with i < j have values[i] + values[j] equal to target."
|
| 383 |
+
),
|
| 384 |
+
solver=_solve_two_sum_count,
|
| 385 |
+
case_builder=_build_two_sum_count_cases,
|
| 386 |
+
),
|
| 387 |
+
ProblemTemplate(
|
| 388 |
+
problem_type="max_subarray_sum",
|
| 389 |
+
difficulty_tier=2,
|
| 390 |
+
title="Maximum Subarray Sum",
|
| 391 |
+
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 392 |
+
constraints="1 <= n <= 18; -50 <= values[i] <= 50",
|
| 393 |
+
statement_builder=lambda: (
|
| 394 |
+
"Print the maximum possible sum of a contiguous subarray."
|
| 395 |
+
),
|
| 396 |
+
solver=_solve_max_subarray_sum,
|
| 397 |
+
case_builder=_build_max_subarray_sum_cases,
|
| 398 |
+
),
|
| 399 |
+
ProblemTemplate(
|
| 400 |
+
problem_type="group_anagrams_count",
|
| 401 |
+
difficulty_tier=2,
|
| 402 |
+
title="Group Anagrams Count",
|
| 403 |
+
input_format="The first line contains n. The second line contains n space-separated lowercase words.",
|
| 404 |
+
constraints="1 <= n <= 12; each word length is between 1 and 8",
|
| 405 |
+
statement_builder=lambda: (
|
| 406 |
+
"Group words that are anagrams of each other. Print the number of distinct anagram groups."
|
| 407 |
+
),
|
| 408 |
+
solver=_solve_group_anagrams_count,
|
| 409 |
+
case_builder=_build_group_anagrams_cases,
|
| 410 |
+
),
|
| 411 |
+
ProblemTemplate(
|
| 412 |
+
problem_type="balanced_brackets",
|
| 413 |
+
difficulty_tier=2,
|
| 414 |
+
title="Balanced Brackets",
|
| 415 |
+
input_format="A single line containing only the characters ()[]{}.",
|
| 416 |
+
constraints="1 <= line length <= 50",
|
| 417 |
+
statement_builder=lambda: (
|
| 418 |
+
"Print YES if the bracket string is balanced and NO otherwise."
|
| 419 |
+
),
|
| 420 |
+
solver=_solve_balanced_brackets,
|
| 421 |
+
case_builder=_build_balanced_brackets_cases,
|
| 422 |
+
),
|
| 423 |
+
ProblemTemplate(
|
| 424 |
+
problem_type="matrix_diagonal_sum",
|
| 425 |
+
difficulty_tier=2,
|
| 426 |
+
title="Matrix Diagonal Sum",
|
| 427 |
+
input_format="The first line contains n. The next n lines each contain n space-separated integers.",
|
| 428 |
+
constraints="2 <= n <= 6; -20 <= matrix[i][j] <= 20",
|
| 429 |
+
statement_builder=lambda: (
|
| 430 |
+
"For the square matrix, print the sum of the primary diagonal and secondary diagonal. "
|
| 431 |
+
"If n is odd, count the center element only once."
|
| 432 |
+
),
|
| 433 |
+
solver=_solve_matrix_diagonal_sum,
|
| 434 |
+
case_builder=_build_matrix_diagonal_sum_cases,
|
| 435 |
+
),
|
| 436 |
ProblemTemplate(
|
| 437 |
problem_type="smallest_most_frequent",
|
| 438 |
difficulty_tier=3,
|
| 439 |
title="Smallest Most Frequent",
|
| 440 |
input_format="The first line contains n. The second line contains n space-separated integers.",
|
| 441 |
+
constraints="1 <= n <= 20; -30 <= values[i] <= 30",
|
| 442 |
+
statement_builder=lambda: (
|
| 443 |
"Print the value that appears most often in the array. If several values have the same highest "
|
| 444 |
"frequency, print the smallest of them."
|
| 445 |
),
|
|
|
|
| 451 |
difficulty_tier=3,
|
| 452 |
title="Reverse Words",
|
| 453 |
input_format="A single line containing one or more words separated by spaces.",
|
| 454 |
+
constraints="1 <= line length <= 120",
|
| 455 |
+
statement_builder=lambda: (
|
| 456 |
"Read a line of text and print the words in reverse order. Multiple spaces in the input should "
|
| 457 |
"be treated as a single separator."
|
| 458 |
),
|
| 459 |
solver=_solve_reverse_words,
|
| 460 |
case_builder=_build_reverse_word_cases,
|
| 461 |
),
|
| 462 |
+
ProblemTemplate(
|
| 463 |
+
problem_type="longest_common_subsequence",
|
| 464 |
+
difficulty_tier=3,
|
| 465 |
+
title="Longest Common Subsequence",
|
| 466 |
+
input_format="The first line contains string s. The second line contains string t.",
|
| 467 |
+
constraints="1 <= len(s), len(t) <= 18; strings contain lowercase letters",
|
| 468 |
+
statement_builder=lambda: (
|
| 469 |
+
"Print the length of the longest common subsequence of the two strings."
|
| 470 |
+
),
|
| 471 |
+
solver=_solve_longest_common_subsequence,
|
| 472 |
+
case_builder=_build_lcs_cases,
|
| 473 |
+
),
|
| 474 |
+
ProblemTemplate(
|
| 475 |
+
problem_type="word_ladder_steps",
|
| 476 |
+
difficulty_tier=3,
|
| 477 |
+
title="Word Ladder Steps",
|
| 478 |
+
input_format="The first line contains start and target. The second line contains n. The third line contains n space-separated words.",
|
| 479 |
+
constraints="All words have the same length between 3 and 5; 1 <= n <= 14",
|
| 480 |
+
statement_builder=lambda: (
|
| 481 |
+
"You may change one character at a time. Every intermediate word and the target word must appear "
|
| 482 |
+
"in the given word list. Print the minimum number of single-character changes needed to transform "
|
| 483 |
+
"start into target, or -1 if it is impossible."
|
| 484 |
+
),
|
| 485 |
+
solver=_solve_word_ladder_steps,
|
| 486 |
+
case_builder=_build_word_ladder_cases,
|
| 487 |
+
),
|
| 488 |
+
ProblemTemplate(
|
| 489 |
+
problem_type="merge_intervals",
|
| 490 |
+
difficulty_tier=3,
|
| 491 |
+
title="Merge Intervals",
|
| 492 |
+
input_format="The first line contains n. The next n lines each contain start and end.",
|
| 493 |
+
constraints="1 <= n <= 12; -20 <= start <= end <= 30",
|
| 494 |
+
statement_builder=lambda: (
|
| 495 |
+
"Merge all overlapping intervals and print how many intervals remain after merging."
|
| 496 |
+
),
|
| 497 |
+
solver=_solve_merge_intervals,
|
| 498 |
+
case_builder=_build_merge_intervals_cases,
|
| 499 |
+
),
|
| 500 |
+
ProblemTemplate(
|
| 501 |
+
problem_type="min_coins",
|
| 502 |
+
difficulty_tier=3,
|
| 503 |
+
title="Minimum Coins",
|
| 504 |
+
input_format="The first line contains n and target. The second line contains n distinct positive coin values.",
|
| 505 |
+
constraints="1 <= n <= 8; 1 <= target <= 40; 1 <= coin values <= 20",
|
| 506 |
+
statement_builder=lambda: (
|
| 507 |
+
"Print the minimum number of coins needed to make exactly target using unlimited copies of the given "
|
| 508 |
+
"coin values. Print -1 if it is impossible."
|
| 509 |
+
),
|
| 510 |
+
solver=_solve_min_coins,
|
| 511 |
+
case_builder=_build_min_coins_cases,
|
| 512 |
+
),
|
| 513 |
+
ProblemTemplate(
|
| 514 |
+
problem_type="rotate_matrix_90",
|
| 515 |
+
difficulty_tier=3,
|
| 516 |
+
title="Rotate Matrix 90 Degrees",
|
| 517 |
+
input_format="The first line contains n. The next n lines each contain n space-separated integers.",
|
| 518 |
+
constraints="2 <= n <= 5; -20 <= matrix[i][j] <= 20",
|
| 519 |
+
statement_builder=lambda: (
|
| 520 |
+
"Rotate the square matrix 90 degrees clockwise and print the rotated matrix flattened in row-major "
|
| 521 |
+
"order on one line separated by spaces."
|
| 522 |
+
),
|
| 523 |
+
solver=_solve_rotate_matrix_90,
|
| 524 |
+
case_builder=_build_rotate_matrix_cases,
|
| 525 |
+
),
|
| 526 |
+
]
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def _build_sum_even_cases(rng: random.Random) -> list[str]:
|
| 530 |
+
visible_pool = [
|
| 531 |
+
_array_case([2, 3, 4]),
|
| 532 |
+
_array_case([1, 3, 5, 7]),
|
| 533 |
+
_array_case([0, -2, 5, 8]),
|
| 534 |
+
]
|
| 535 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _sum_even_hidden_case)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _sum_even_hidden_case(rng: random.Random) -> str:
|
| 539 |
+
length = rng.randint(5, 12)
|
| 540 |
+
numbers = [rng.randint(-50, 50) for _ in range(length)]
|
| 541 |
+
if all(number % 2 for number in numbers):
|
| 542 |
+
numbers[rng.randrange(length)] = rng.choice([-8, -2, 0, 6, 14])
|
| 543 |
+
return _array_case(numbers)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def _build_range_span_cases(rng: random.Random) -> list[str]:
|
| 547 |
+
visible_pool = [
|
| 548 |
+
_array_case([1, 4, 9]),
|
| 549 |
+
_array_case([-2, -2, -2, 1]),
|
| 550 |
+
_array_case([8, 3]),
|
| 551 |
+
]
|
| 552 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _range_span_hidden_case)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _range_span_hidden_case(rng: random.Random) -> str:
|
| 556 |
+
length = rng.randint(4, 12)
|
| 557 |
+
numbers = [rng.randint(-60, 60) for _ in range(length)]
|
| 558 |
+
if len(set(numbers)) == 1:
|
| 559 |
+
numbers[-1] += 5
|
| 560 |
+
return _array_case(numbers)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def _build_count_vowels_cases(rng: random.Random) -> list[str]:
|
| 564 |
+
visible_pool = [
|
| 565 |
+
"hello world\n",
|
| 566 |
+
"sky\n",
|
| 567 |
+
"AEIOU\n",
|
| 568 |
+
]
|
| 569 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _count_vowels_hidden_case)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def _count_vowels_hidden_case(rng: random.Random) -> str:
|
| 573 |
+
word_bank = [
|
| 574 |
+
"algorithm",
|
| 575 |
+
"queue",
|
| 576 |
+
"stack",
|
| 577 |
+
"binary",
|
| 578 |
+
"graph",
|
| 579 |
+
"open env",
|
| 580 |
+
"unit test",
|
| 581 |
+
"dynamic programming",
|
| 582 |
+
"vowel heavy area",
|
| 583 |
+
"crypt rhythm",
|
| 584 |
+
]
|
| 585 |
+
parts = [rng.choice(word_bank) for _ in range(rng.randint(1, 3))]
|
| 586 |
+
text = " ".join(parts)
|
| 587 |
+
return f"{text[:80]}\n"
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def _build_max_consecutive_ones_cases(rng: random.Random) -> list[str]:
|
| 591 |
+
visible_pool = ["1101110\n", "00000\n", "1\n"]
|
| 592 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _max_consecutive_ones_hidden_case)
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def _max_consecutive_ones_hidden_case(rng: random.Random) -> str:
|
| 596 |
+
length = rng.randint(8, 40)
|
| 597 |
+
chars = [rng.choice(["0", "1"]) for _ in range(length)]
|
| 598 |
+
if "1" not in chars:
|
| 599 |
+
start = rng.randint(0, max(0, length - 3))
|
| 600 |
+
run_length = rng.randint(1, min(5, length - start))
|
| 601 |
+
for index in range(start, start + run_length):
|
| 602 |
+
chars[index] = "1"
|
| 603 |
+
return f"{''.join(chars)}\n"
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def _build_fizzbuzz_variant_cases(rng: random.Random) -> list[str]:
|
| 607 |
+
visible_pool = [
|
| 608 |
+
_fizzbuzz_case(8, 3, 5, "Fizz", "Buzz"),
|
| 609 |
+
_fizzbuzz_case(6, 2, 4, "Hop", "Pop"),
|
| 610 |
+
_fizzbuzz_case(10, 2, 3, "Up", "Go"),
|
| 611 |
+
]
|
| 612 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _fizzbuzz_hidden_case)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def _fizzbuzz_hidden_case(rng: random.Random) -> str:
|
| 616 |
+
labels = [
|
| 617 |
+
("Fizz", "Buzz"),
|
| 618 |
+
("Ping", "Pong"),
|
| 619 |
+
("Hop", "Skip"),
|
| 620 |
+
("Alpha", "Beta"),
|
| 621 |
+
("Red", "Blue"),
|
| 622 |
+
]
|
| 623 |
+
label_a, label_b = rng.choice(labels)
|
| 624 |
+
a = rng.randint(2, 6)
|
| 625 |
+
b = rng.randint(2, 6)
|
| 626 |
+
while b == a:
|
| 627 |
+
b = rng.randint(2, 6)
|
| 628 |
+
n = rng.randint(10, 25)
|
| 629 |
+
return _fizzbuzz_case(n, a, b, label_a, label_b)
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def _build_running_total_cases(rng: random.Random) -> list[str]:
|
| 633 |
+
visible_pool = [
|
| 634 |
+
_array_case([1, 2, 3, 4]),
|
| 635 |
+
_array_case([5, -2, 7]),
|
| 636 |
+
_array_case([0, 0, 1]),
|
| 637 |
+
]
|
| 638 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _running_total_hidden_case)
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def _running_total_hidden_case(rng: random.Random) -> str:
|
| 642 |
+
length = rng.randint(5, 14)
|
| 643 |
+
numbers = [rng.randint(-20, 20) for _ in range(length)]
|
| 644 |
+
return _array_case(numbers)
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def _build_peak_cases(rng: random.Random) -> list[str]:
|
| 648 |
+
visible_pool = [
|
| 649 |
+
_array_case([1, 3, 2, 4, 1]),
|
| 650 |
+
_array_case([5, 4, 3, 2, 1]),
|
| 651 |
+
_array_case([2, 5, 1, 5, 2]),
|
| 652 |
+
]
|
| 653 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _peak_hidden_case)
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def _peak_hidden_case(rng: random.Random) -> str:
|
| 657 |
+
length = rng.randint(6, 16)
|
| 658 |
+
numbers = [rng.randint(-20, 20)]
|
| 659 |
+
for _ in range(length - 1):
|
| 660 |
+
numbers.append(numbers[-1] + rng.randint(-8, 8))
|
| 661 |
+
return _array_case(numbers)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def _build_run_cases(rng: random.Random) -> list[str]:
|
| 665 |
+
visible_pool = [
|
| 666 |
+
_array_case([1, 2, 2, 1, 3]),
|
| 667 |
+
_array_case([5, 4, 3, 2]),
|
| 668 |
+
_array_case([1, 1, 1, 1]),
|
| 669 |
+
]
|
| 670 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _run_hidden_case)
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def _run_hidden_case(rng: random.Random) -> str:
|
| 674 |
+
length = rng.randint(6, 18)
|
| 675 |
+
numbers = [rng.randint(-20, 20)]
|
| 676 |
+
for _ in range(length - 1):
|
| 677 |
+
numbers.append(numbers[-1] + rng.randint(-6, 6))
|
| 678 |
+
return _array_case(numbers)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def _build_two_sum_count_cases(rng: random.Random) -> list[str]:
|
| 682 |
+
visible_pool = [
|
| 683 |
+
_target_array_case(5, [1, 2, 3, 4]),
|
| 684 |
+
_target_array_case(2, [1, 1, 1, 1]),
|
| 685 |
+
_target_array_case(0, [-1, 1, 2, -2]),
|
| 686 |
+
]
|
| 687 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _two_sum_hidden_case)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def _two_sum_hidden_case(rng: random.Random) -> str:
|
| 691 |
+
length = rng.randint(5, 16)
|
| 692 |
+
numbers = [rng.randint(-12, 12) for _ in range(length)]
|
| 693 |
+
target = rng.randint(-10, 10)
|
| 694 |
+
return _target_array_case(target, numbers)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def _build_max_subarray_sum_cases(rng: random.Random) -> list[str]:
|
| 698 |
+
visible_pool = [
|
| 699 |
+
_array_case([1, -2, 3, 4, -1]),
|
| 700 |
+
_array_case([-5, -1, -8]),
|
| 701 |
+
_array_case([2, -1, 2, 3, 4, -5]),
|
| 702 |
]
|
| 703 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _max_subarray_hidden_case)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def _max_subarray_hidden_case(rng: random.Random) -> str:
|
| 707 |
+
length = rng.randint(6, 18)
|
| 708 |
+
numbers = [rng.randint(-20, 20) for _ in range(length)]
|
| 709 |
+
return _array_case(numbers)
|
| 710 |
|
| 711 |
|
| 712 |
+
def _build_group_anagrams_cases(rng: random.Random) -> list[str]:
|
| 713 |
+
visible_pool = [
|
| 714 |
+
_word_list_case(["eat", "tea", "tan", "ate", "nat", "bat"]),
|
| 715 |
+
_word_list_case(["abc", "bca", "cab", "foo"]),
|
| 716 |
+
_word_list_case(["a", "b", "ab", "ba"]),
|
| 717 |
+
]
|
| 718 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _group_anagrams_hidden_case)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def _group_anagrams_hidden_case(rng: random.Random) -> str:
|
| 722 |
+
base_words = ["stone", "tones", "notes", "silent", "listen", "enlist", "rat", "tar", "art"]
|
| 723 |
+
words: list[str] = []
|
| 724 |
+
for _ in range(rng.randint(4, 10)):
|
| 725 |
+
word = rng.choice(base_words)
|
| 726 |
+
if rng.random() < 0.4:
|
| 727 |
+
shuffled = list(word)
|
| 728 |
+
rng.shuffle(shuffled)
|
| 729 |
+
words.append("".join(shuffled))
|
| 730 |
+
else:
|
| 731 |
+
words.append(word)
|
| 732 |
+
return _word_list_case(words)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def _build_balanced_brackets_cases(rng: random.Random) -> list[str]:
|
| 736 |
+
visible_pool = [
|
| 737 |
+
"([]{})\n",
|
| 738 |
+
"([)]\n",
|
| 739 |
+
"{[()]}\n",
|
| 740 |
+
]
|
| 741 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _balanced_brackets_hidden_case)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def _balanced_brackets_hidden_case(rng: random.Random) -> str:
|
| 745 |
+
if rng.random() < 0.5:
|
| 746 |
+
return f"{_make_balanced_brackets(rng, rng.randint(3, 10))}\n"
|
| 747 |
+
return f"{_make_unbalanced_brackets(rng, rng.randint(3, 10))}\n"
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def _build_matrix_diagonal_sum_cases(rng: random.Random) -> list[str]:
|
| 751 |
+
visible_pool = [
|
| 752 |
+
_matrix_case([[1, 2], [3, 4]]),
|
| 753 |
+
_matrix_case([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
|
| 754 |
+
_matrix_case([[2, 0, 2], [1, 5, 1], [2, 0, 2]]),
|
| 755 |
+
]
|
| 756 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _matrix_diagonal_hidden_case)
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def _matrix_diagonal_hidden_case(rng: random.Random) -> str:
|
| 760 |
+
size = rng.randint(3, 6)
|
| 761 |
+
matrix = [[rng.randint(-9, 9) for _ in range(size)] for _ in range(size)]
|
| 762 |
+
return _matrix_case(matrix)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def _build_frequency_cases(rng: random.Random) -> list[str]:
|
| 766 |
+
visible_pool = [
|
| 767 |
+
_array_case([1, 2, 2, 3, 3, 3]),
|
| 768 |
+
_array_case([4, 4, 1, 1]),
|
| 769 |
+
_array_case([-1, -1, -2, -2, -2, 3]),
|
| 770 |
+
]
|
| 771 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _frequency_hidden_case)
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def _frequency_hidden_case(rng: random.Random) -> str:
|
| 775 |
+
length = rng.randint(8, 20)
|
| 776 |
+
numbers = [rng.randint(-8, 8) for _ in range(length)]
|
| 777 |
+
numbers.extend([rng.choice(numbers), rng.choice(numbers)])
|
| 778 |
+
return _array_case(numbers)
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def _build_reverse_word_cases(rng: random.Random) -> list[str]:
|
| 782 |
+
visible_pool = [
|
| 783 |
+
"hello world here\n",
|
| 784 |
+
" graph search tree \n",
|
| 785 |
+
"one\n",
|
| 786 |
+
]
|
| 787 |
+
return _cases_from_pool_and_factory(rng, visible_pool, _reverse_words_hidden_case)
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
def _reverse_words_hidden_case(rng: random.Random) -> str:
|
| 791 |
vocabulary = [
|
| 792 |
"graph",
|
| 793 |
"queue",
|
|
|
|
| 802 |
"node",
|
| 803 |
"edge",
|
| 804 |
]
|
| 805 |
+
words = [rng.choice(vocabulary) for _ in range(rng.randint(4, 9))]
|
| 806 |
+
spacer = " " * rng.randint(1, 3)
|
| 807 |
+
prefix = " " * rng.randint(0, 2)
|
| 808 |
+
suffix = " " * rng.randint(0, 2)
|
| 809 |
+
return f"{prefix}{spacer.join(words)}{suffix}\n"
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def _build_lcs_cases(rng: random.Random) -> list[str]:
    """Assemble the stdin cases for the longest-common-subsequence problem."""
    fixed_pairs = (
        ("abcde", "ace"),
        ("abc", "abc"),
        ("abc", "def"),
    )
    visible_pool = [_two_line_case(first, second) for first, second in fixed_pairs]
    return _cases_from_pool_and_factory(rng, visible_pool, _lcs_hidden_case)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def _lcs_hidden_case(rng: random.Random) -> str:
    """Generate two random strings (length 6..14 each) over the letters ``abcdxyz``."""
    letters = "abcdxyz"

    def draw_string() -> str:
        # The length is drawn before any letter, as in a generator over range(randint(...)).
        size = rng.randint(6, 14)
        return "".join([rng.choice(letters) for _ in range(size)])

    first = draw_string()
    second = draw_string()
    return _two_line_case(first, second)
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def _build_word_ladder_cases(rng: random.Random) -> list[str]:
    """Assemble the stdin cases for the word-ladder problem."""
    fixed_ladders = (
        ("hit", "cog", ["hot", "dot", "dog", "lot", "log", "cog"]),
        ("same", "same", ["same", "lame", "came"]),
        ("cold", "warm", ["cord", "card", "ward", "sold"]),
    )
    visible_pool = [
        _word_ladder_case(start, target, words) for start, target, words in fixed_ladders
    ]
    return _cases_from_pool_and_factory(rng, visible_pool, _word_ladder_hidden_case)
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
def _word_ladder_hidden_case(rng: random.Random) -> str:
    """Generate one hidden word-ladder case with words of length 3..5.

    When the first draw is below 0.7, a ladder path is built and its words
    (minus the start) are mixed with random extras and shuffled. Otherwise
    two distinct random words are used and the target is kept out of the
    word list.
    """
    word_length = rng.randint(3, 5)

    if rng.random() >= 0.7:
        # Target deliberately absent from the dictionary.
        start = _random_word(rng, word_length)
        target = _random_word(rng, word_length)
        while target == start:
            target = _random_word(rng, word_length)
        extra_count = rng.randint(4, 10)
        fillers = _build_word_ladder_extras(rng, word_length, extra_count, {start, target})
        return _word_ladder_case(start, target, fillers)

    hop_count = rng.randint(2, 5)
    ladder = _build_word_ladder_path(rng, word_length, hop_count)
    extra_count = rng.randint(2, 7)
    fillers = _build_word_ladder_extras(rng, word_length, extra_count, set(ladder))
    dictionary = ladder[1:] + fillers
    rng.shuffle(dictionary)
    return _word_ladder_case(ladder[0], ladder[-1], dictionary)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def _build_merge_intervals_cases(rng: random.Random) -> list[str]:
    """Assemble the stdin cases for the merge-intervals problem."""
    fixed_interval_sets = (
        [(1, 3), (2, 4), (6, 8)],
        [(1, 2), (3, 4), (5, 6)],
        [(0, 5), (2, 3), (4, 10)],
    )
    visible_pool = [_interval_case(intervals) for intervals in fixed_interval_sets]
    return _cases_from_pool_and_factory(rng, visible_pool, _merge_intervals_hidden_case)
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
def _merge_intervals_hidden_case(rng: random.Random) -> str:
    """Generate 4..12 random intervals with start in -10..20 and width 0..8."""
    interval_count = rng.randint(4, 12)
    intervals: list[tuple[int, int]] = []
    while len(intervals) < interval_count:
        left = rng.randint(-10, 20)
        width = rng.randint(0, 8)
        intervals.append((left, left + width))
    return _interval_case(intervals)
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
def _build_min_coins_cases(rng: random.Random) -> list[str]:
    """Assemble the stdin cases for the minimum-coins problem."""
    fixed_problems = (
        ([1, 3, 4], 6),
        ([2, 5], 3),
        ([2, 5, 7], 14),
    )
    visible_pool = [_coin_case(coins, target) for coins, target in fixed_problems]
    return _cases_from_pool_and_factory(rng, visible_pool, _min_coins_hidden_case)
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
def _min_coins_hidden_case(rng: random.Random) -> str:
    """Generate a random coin-change case with a target in 5..40.

    ``wanted + 2`` denominations are drawn from 1..10 and de-duplicated;
    the smallest ``wanted`` distinct values are kept, so the final list can
    be shorter than ``wanted`` when draws collide.
    """
    wanted = rng.randint(2, 6)
    distinct: set[int] = set()
    for _ in range(wanted + 2):
        distinct.add(rng.randint(1, 10))
    denominations = sorted(distinct)[:wanted]
    amount = rng.randint(5, 40)
    return _coin_case(denominations, amount)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def _build_rotate_matrix_cases(rng: random.Random) -> list[str]:
    """Assemble the stdin cases for the rotate-matrix problem."""
    fixed_matrices = (
        [[1, 2], [3, 4]],
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[5, 1], [0, -1]],
    )
    visible_pool = [_matrix_case(matrix) for matrix in fixed_matrices]
    return _cases_from_pool_and_factory(rng, visible_pool, _rotate_matrix_hidden_case)
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def _rotate_matrix_hidden_case(rng: random.Random) -> str:
    """Generate a random square-matrix case: side 2..5, entries in -9..9."""
    side = rng.randint(2, 5)
    grid: list[list[int]] = []
    for _ in range(side):
        grid.append([rng.randint(-9, 9) for _ in range(side)])
    return _matrix_case(grid)
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def _cases_from_pool_and_factory(
    rng: random.Random,
    visible_pool: list[str],
    hidden_factory: Callable[[random.Random], str],
) -> list[str]:
    """Build the full case list: sampled pool entries first, generated cases after.

    ``VISIBLE_TEST_COUNT`` entries are sampled from ``visible_pool``; then
    ``hidden_factory`` is called until the list holds ``TOTAL_TEST_CASES``
    entries, skipping any input already present.

    Raises:
        ValueError: when a duplicate is produced after more than 200 factory
            calls in total (the counter includes successful calls).
    """
    selected: list[str] = list(rng.sample(visible_pool, k=VISIBLE_TEST_COUNT))
    known: set[str] = set(selected)

    factory_calls = 0
    while len(selected) < TOTAL_TEST_CASES:
        candidate = hidden_factory(rng)
        factory_calls += 1
        if candidate not in known:
            known.add(candidate)
            selected.append(candidate)
            continue
        if factory_calls > 200:
            raise ValueError("Unable to generate unique test cases.")

    return selected
|
| 929 |
|
| 930 |
|
| 931 |
def _array_case(numbers: list[int]) -> str:
    """Format ``numbers`` as a count line followed by a space-separated value line."""
    values_line = " ".join(map(str, numbers))
    return f"{len(numbers)}\n{values_line}\n"
|
| 933 |
|
| 934 |
|
| 935 |
+
def _target_array_case(target: int, numbers: list[int]) -> str:
    """Format a ``"<count> <target>"`` header line followed by the values line."""
    header = f"{len(numbers)} {target}"
    values_line = " ".join(map(str, numbers))
    return header + "\n" + values_line + "\n"
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
def _word_list_case(words: list[str]) -> str:
    """Format ``words`` as a count line followed by a space-separated word line."""
    count_line = str(len(words))
    word_line = " ".join(words)
    return count_line + "\n" + word_line + "\n"
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
def _matrix_case(matrix: list[list[int]]) -> str:
    """Format ``matrix`` as a row-count line followed by one space-separated line per row."""
    body = "\n".join(" ".join(map(str, row)) for row in matrix)
    return f"{len(matrix)}\n{body}\n"
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def _two_line_case(first: str, second: str) -> str:
    """Put ``first`` and ``second`` on their own newline-terminated lines."""
    return "".join((first, "\n", second, "\n"))
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def _interval_case(intervals: list[tuple[int, int]]) -> str:
    """Format ``intervals`` as a count line followed by one ``"start end"`` line each."""
    body = "\n".join(f"{left} {right}" for left, right in intervals)
    return f"{len(intervals)}\n{body}\n"
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
def _coin_case(coins: list[int], target: int) -> str:
    """Format a ``"<coin count> <target>"`` header line followed by the coin values line."""
    header = f"{len(coins)} {target}"
    coins_line = " ".join(map(str, coins))
    return header + "\n" + coins_line + "\n"
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def _fizzbuzz_case(n: int, a: int, b: int, label_a: str, label_b: str) -> str:
    """Format the numeric line ``"n a b"`` followed by the label line ``"label_a label_b"``."""
    numbers_line = f"{n} {a} {b}"
    labels_line = f"{label_a} {label_b}"
    return numbers_line + "\n" + labels_line + "\n"
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def _word_ladder_case(start: str, target: str, words: list[str]) -> str:
    """Format a word-ladder case: endpoints line, word-count line, then the word list line."""
    endpoints_line = f"{start} {target}"
    count_line = str(len(words))
    words_line = " ".join(words)
    return endpoints_line + "\n" + count_line + "\n" + words_line + "\n"
|
| 967 |
+
|
| 968 |
+
|
| 969 |
def _solve_sum_even_numbers(stdin: str) -> str:
    """Reference solution: sum of the even values in the parsed array, as a string."""
    _, numbers = _parse_int_array(stdin)
    even_total = 0
    for value in numbers:
        if value % 2 == 0:
            even_total += value
    return str(even_total)
|
|
|
|
| 976 |
return str(max(numbers) - min(numbers))
|
| 977 |
|
| 978 |
|
| 979 |
+
def _solve_count_vowels(stdin: str) -> str:
    """Reference solution: count of ``a/e/i/o/u`` (case-insensitive) in the input text."""
    lowered = stdin.rstrip("\n").lower()
    vowel_total = 0
    for vowel in "aeiou":
        vowel_total += lowered.count(vowel)
    return str(vowel_total)
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
def _solve_max_consecutive_ones(stdin: str) -> str:
    """Reference solution: length of the longest run of ``"1"`` characters.

    Any character other than ``"1"`` ends the current run.
    """
    longest = 0
    streak = 0
    for symbol in stdin.strip():
        streak = streak + 1 if symbol == "1" else 0
        if streak > longest:
            longest = streak
    return str(longest)
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def _solve_fizzbuzz_variant(stdin: str) -> str:
    """Reference solution for the parameterised FizzBuzz.

    For each value 1..n the output token is ``label_a`` and/or ``label_b``
    (concatenated in that order) when the value is divisible by ``a`` /
    ``b``; otherwise the value itself. Tokens are space-separated.
    """
    (n, a, b), (label_a, label_b) = _parse_fizzbuzz(stdin)

    def render(value: int) -> str:
        first_part = label_a if value % a == 0 else ""
        second_part = label_b if value % b == 0 else ""
        label = first_part + second_part
        return label if label else str(value)

    return " ".join([render(value) for value in range(1, n + 1)])
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
def _solve_running_total(stdin: str) -> str:
    """Reference solution: space-separated prefix sums of the parsed array."""
    _, numbers = _parse_int_array(stdin)
    prefix_sums: list[int] = []
    for value in numbers:
        previous = prefix_sums[-1] if prefix_sums else 0
        prefix_sums.append(previous + value)
    return " ".join(map(str, prefix_sums))
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
def _solve_count_local_peaks(stdin: str) -> str:
|
| 1021 |
_, numbers = _parse_int_array(stdin)
|
| 1022 |
peaks = 0
|
|
|
|
| 1039 |
return str(best)
|
| 1040 |
|
| 1041 |
|
| 1042 |
+
def _solve_two_sum_count(stdin: str) -> str:
    """Reference solution: number of index pairs ``i < j`` whose values sum to the target."""
    _, target, numbers = _parse_target_array(stdin)
    occurrences: dict[int, int] = {}
    pair_total = 0
    for value in numbers:
        # Every earlier occurrence of the complement forms one pair with this value.
        pair_total += occurrences.get(target - value, 0)
        occurrences[value] = occurrences.get(value, 0) + 1
    return str(pair_total)
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def _solve_max_subarray_sum(stdin: str) -> str:
    """Reference solution: maximum sum over all non-empty contiguous subarrays."""
    _, numbers = _parse_int_array(stdin)
    running = overall = numbers[0]
    for value in numbers[1:]:
        extended = running + value
        # Either extend the current run or restart at this value.
        running = value if value > extended else extended
        if running > overall:
            overall = running
    return str(overall)
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
def _solve_group_anagrams_count(stdin: str) -> str:
    """Reference solution: number of distinct anagram groups among the parsed words."""
    _, words = _parse_word_list(stdin)
    signatures: set[str] = set()
    for word in words:
        # Anagrams share the same sorted-letter signature.
        signatures.add("".join(sorted(word)))
    return str(len(signatures))
|
| 1066 |
+
|
| 1067 |
+
|
| 1068 |
+
def _solve_balanced_brackets(stdin: str) -> str:
    """Reference solution: ``"YES"`` when ``()[]{}`` brackets are balanced, else ``"NO"``.

    Characters that are not brackets are ignored.
    """
    opener_for = {")": "(", "]": "[", "}": "{"}
    openers = set(opener_for.values())
    pending: list[str] = []
    for symbol in stdin.strip():
        if symbol in openers:
            pending.append(symbol)
            continue
        expected = opener_for.get(symbol)
        if expected is None:
            continue
        if not pending or pending.pop() != expected:
            return "NO"
    return "NO" if pending else "YES"
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
def _solve_matrix_diagonal_sum(stdin: str) -> str:
    """Reference solution: sum of both diagonals, counting a shared centre cell once."""
    _, matrix = _parse_matrix(stdin)
    last_column = len(matrix) - 1
    diagonal_total = 0
    for row_index, row in enumerate(matrix):
        diagonal_total += row[row_index]
        anti_column = last_column - row_index
        if anti_column != row_index:
            diagonal_total += row[anti_column]
    return str(diagonal_total)
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
def _solve_smallest_most_frequent(stdin: str) -> str:
    """Reference solution: the smallest value among those with the highest frequency."""
    _, numbers = _parse_int_array(stdin)
    counts = Counter(numbers)
    top_frequency = max(counts.values())
    tied_values = [value for value, count in counts.items() if count == top_frequency]
    return str(min(tied_values))
|
|
|
|
| 1103 |
return " ".join(reversed(words))
|
| 1104 |
|
| 1105 |
|
| 1106 |
+
def _solve_longest_common_subsequence(stdin: str) -> str:
    """Reference solution: length of the longest common subsequence of the two input lines."""
    left, right = _parse_two_strings(stdin)
    # Only the previous DP row is needed to compute the next one.
    previous_row = [0] * (len(right) + 1)
    for left_char in left:
        current_row = [0]
        for column, right_char in enumerate(right, start=1):
            if left_char == right_char:
                current_row.append(previous_row[column - 1] + 1)
            else:
                current_row.append(max(previous_row[column], current_row[column - 1]))
        previous_row = current_row
    return str(previous_row[-1])
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
def _solve_word_ladder_steps(stdin: str) -> str:
    """Reference solution: fewest single-letter changes from start to target.

    Returns ``"0"`` when start equals target, ``"-1"`` when the target is
    not in the word list or cannot be reached; intermediate words must
    belong to the word list.
    """
    start, target, words = _parse_word_ladder(stdin)
    if start == target:
        return "0"
    dictionary = set(words)
    if target not in dictionary:
        return "-1"

    letters = "abcdefghijklmnopqrstuvwxyz"
    frontier: deque[tuple[str, int]] = deque()
    frontier.append((start, 0))
    explored = {start}

    # Breadth-first search: the first time the target is produced is the shortest route.
    while frontier:
        word, distance = frontier.popleft()
        next_distance = distance + 1
        for position, original_letter in enumerate(word):
            prefix = word[:position]
            suffix = word[position + 1 :]
            for letter in letters:
                if letter == original_letter:
                    continue
                neighbour = prefix + letter + suffix
                if neighbour == target:
                    return str(next_distance)
                if neighbour in dictionary and neighbour not in explored:
                    explored.add(neighbour)
                    frontier.append((neighbour, next_distance))
    return "-1"
|
| 1143 |
+
|
| 1144 |
+
|
| 1145 |
+
def _solve_merge_intervals(stdin: str) -> str:
    """Reference solution: number of intervals left after merging overlapping/touching ones."""
    intervals = _parse_intervals(stdin)
    merged_count = 0
    reach = None  # right edge of the group currently being merged
    for left, right in sorted(intervals):
        if reach is None or left > reach:
            merged_count += 1
            reach = right
        else:
            reach = max(reach, right)
    return str(merged_count)
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
def _solve_min_coins(stdin: str) -> str:
    """Reference solution: fewest coins (unlimited supply) summing to the target, or ``-1``."""
    _, target, coins = _parse_coin_problem(stdin)
    unreachable = target + 1  # larger than any feasible coin count
    fewest = [unreachable] * (target + 1)
    fewest[0] = 0
    for amount in range(1, target + 1):
        options = [fewest[amount - coin] + 1 for coin in coins if coin <= amount]
        fewest[amount] = min([fewest[amount], *options])
    answer = fewest[target]
    return str(answer if answer <= target else -1)
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
def _solve_rotate_matrix_90(stdin: str) -> str:
    """Reference solution: the matrix rotated 90 degrees clockwise, flattened row by row."""
    _, matrix = _parse_matrix(stdin)
    # Reading the columns of the vertically flipped matrix yields the clockwise rotation.
    rotated_rows = zip(*matrix[::-1])
    cells: list[str] = []
    for row in rotated_rows:
        cells.extend(str(value) for value in row)
    return " ".join(cells)
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
def _parse_int_array(stdin: str) -> tuple[int, list[int]]:
|
| 1177 |
lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
|
| 1178 |
n = int(lines[0])
|
|
|
|
| 1180 |
if len(numbers) != n:
|
| 1181 |
raise ValueError(f"Expected {n} integers, received {len(numbers)}")
|
| 1182 |
return n, numbers
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
def _parse_target_array(stdin: str) -> tuple[int, int, list[int]]:
    """Parse ``"<n> <target>"`` followed by a line of ``n`` integers.

    Returns:
        ``(n, target, numbers)``.

    Raises:
        ValueError: if a token is not an integer, the header does not hold
            exactly two values, or the number of values differs from ``n``.
    """
    lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    n, target = map(int, lines[0].split())
    # An empty array serialises to a blank values line, which the filter above
    # drops; treat the missing line as "no values" instead of failing on lines[1].
    numbers = [int(part) for part in lines[1].split()] if len(lines) > 1 else []
    if len(numbers) != n:
        raise ValueError(f"Expected {n} integers, received {len(numbers)}")
    return n, target, numbers
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
def _parse_word_list(stdin: str) -> tuple[int, list[str]]:
    """Parse a count line followed by a line of ``n`` whitespace-separated words.

    Returns:
        ``(n, words)``.

    Raises:
        ValueError: if the count is not an integer or does not match the
            number of words.
    """
    lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    n = int(lines[0])
    # An empty word list serialises to a blank line, which the filter above
    # drops; treat the missing line as "no words" instead of failing on lines[1].
    words = lines[1].split() if len(lines) > 1 else []
    if len(words) != n:
        raise ValueError(f"Expected {n} words, received {len(words)}")
    return n, words
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
def _parse_matrix(stdin: str) -> tuple[int, list[list[int]]]:
    """Parse a size line followed by ``n`` rows of ``n`` integers.

    Raises:
        ValueError: if a token is not an integer or the matrix is not ``n`` by ``n``.
    """
    rows = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    size = int(rows[0])
    grid: list[list[int]] = []
    for row_text in rows[1 : size + 1]:
        grid.append([int(token) for token in row_text.split()])
    is_square = len(grid) == size and all(len(row) == size for row in grid)
    if not is_square:
        raise ValueError("Matrix dimensions do not match n.")
    return size, grid
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
def _parse_two_strings(stdin: str) -> tuple[str, str]:
    """Return the first two lines of the stripped input, each stripped of whitespace.

    Raises:
        ValueError: if fewer than two lines remain after stripping the input.
    """
    rows = stdin.strip().splitlines()
    if len(rows) <= 1:
        raise ValueError("Expected two lines of text.")
    first, second = rows[:2]
    return first.strip(), second.strip()
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
def _parse_fizzbuzz(stdin: str) -> tuple[tuple[int, int, int], tuple[str, str]]:
    """Parse the ``"n a b"`` line and the ``"label_a label_b"`` line.

    Returns:
        ``((n, a, b), (label_a, label_b))``.
    """
    rows = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    numbers_row, labels_row = rows[0], rows[1]
    n, a, b = (int(token) for token in numbers_row.split())
    label_a, label_b = labels_row.split()
    return (n, a, b), (label_a, label_b)
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
def _parse_word_ladder(stdin: str) -> tuple[str, str, list[str]]:
    """Parse a word-ladder case: ``"start target"``, a word count, then the words.

    Returns:
        ``(start, target, words)``.

    Raises:
        ValueError: if the first line does not hold exactly two words, the
            count is not an integer, or the count does not match the number
            of words.
    """
    lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    start, target = lines[0].split()
    n = int(lines[1])
    # An empty dictionary serialises to a blank words line, which the filter
    # above drops; treat the missing line as "no words" instead of failing on
    # lines[2].
    words = lines[2].split() if len(lines) > 2 else []
    if len(words) != n:
        raise ValueError(f"Expected {n} words, received {len(words)}")
    return start, target, words
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
def _parse_intervals(stdin: str) -> list[tuple[int, int]]:
    """Parse a count line followed by ``n`` lines of ``"start end"`` integers.

    Returns:
        The intervals as ``(start, end)`` tuples, in input order.

    Raises:
        ValueError: if a token is not an integer, fewer than ``n`` interval
            rows are present, or a row does not hold exactly two integers.
    """
    lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    n = int(lines[0])
    rows = lines[1 : n + 1]
    if len(rows) != n:
        raise ValueError("Interval count does not match n.")

    intervals: list[tuple[int, int]] = []
    for row in rows:
        values = [int(part) for part in row.split()]
        # Validate arity explicitly so a malformed row gives a readable error
        # rather than a tuple-unpacking failure.
        if len(values) != 2:
            raise ValueError(f"Expected 2 integers per interval, received {len(values)}")
        intervals.append((values[0], values[1]))
    return intervals
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def _parse_coin_problem(stdin: str) -> tuple[int, int, list[int]]:
    """Parse ``"<n> <target>"`` followed by a line of ``n`` coin values.

    Returns:
        ``(n, target, coins)``.

    Raises:
        ValueError: if a token is not an integer, the header does not hold
            exactly two values, or the number of coins differs from ``n``.
    """
    lines = [line.strip() for line in stdin.strip().splitlines() if line.strip()]
    n, target = map(int, lines[0].split())
    # An empty coin list serialises to a blank line, which the filter above
    # drops; treat the missing line as "no coins" instead of failing on lines[1].
    coins = [int(part) for part in lines[1].split()] if len(lines) > 1 else []
    if len(coins) != n:
        raise ValueError(f"Expected {n} coins, received {len(coins)}")
    return n, target, coins
|
| 1252 |
+
|
| 1253 |
+
|
| 1254 |
+
def _make_balanced_brackets(rng: random.Random, pairs: int) -> str:
|
| 1255 |
+
opens = ["(", "[", "{"]
|
| 1256 |
+
closing = {"(": ")", "[": "]", "{": "}"}
|
| 1257 |
+
stack: list[str] = []
|
| 1258 |
+
output: list[str] = []
|
| 1259 |
+
for _ in range(pairs * 2):
|
| 1260 |
+
can_open = len(stack) < pairs and (not stack or rng.random() < 0.6)
|
| 1261 |
+
if can_open:
|
| 1262 |
+
token = rng.choice(opens)
|
| 1263 |
+
stack.append(token)
|
| 1264 |
+
output.append(token)
|
| 1265 |
+
else:
|
| 1266 |
+
output.append(closing[stack.pop()])
|
| 1267 |
+
while stack:
|
| 1268 |
+
output.append(closing[stack.pop()])
|
| 1269 |
+
return "".join(output)
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
def _make_unbalanced_brackets(rng: random.Random, pairs: int) -> str:
|
| 1273 |
+
text = list(_make_balanced_brackets(rng, pairs))
|
| 1274 |
+
if not text:
|
| 1275 |
+
return "("
|
| 1276 |
+
mode = rng.choice(["swap", "drop", "flip"])
|
| 1277 |
+
if mode == "swap" and len(text) >= 2:
|
| 1278 |
+
index = rng.randrange(len(text) - 1)
|
| 1279 |
+
text[index], text[index + 1] = text[index + 1], text[index]
|
| 1280 |
+
elif mode == "drop":
|
| 1281 |
+
del text[rng.randrange(len(text))]
|
| 1282 |
+
else:
|
| 1283 |
+
replacements = ["(", ")", "[", "]", "{", "}"]
|
| 1284 |
+
text[rng.randrange(len(text))] = rng.choice(replacements)
|
| 1285 |
+
return "".join(text)
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
def _random_word(rng: random.Random, length: int) -> str:
    """Return ``length`` letters drawn independently from ``a``-``z``."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    picked: list[str] = []
    for _ in range(length):
        picked.append(rng.choice(letters))
    return "".join(picked)
|
| 1291 |
+
|
| 1292 |
+
|
| 1293 |
+
def _build_word_ladder_path(rng: random.Random, length: int, steps: int) -> list[str]:
    """Build ``steps + 1`` distinct words where neighbours differ in exactly one letter.

    Starting from a random word, each step rewrites one random position
    with a different letter; candidates already on the path are redrawn.
    """
    letters = "abcdefghijklmnopqrstuvwxyz"
    first_word = _random_word(rng, length)
    ladder = [first_word]
    taken = {first_word}
    while len(ladder) <= steps:
        tail = ladder[-1]
        position = rng.randrange(length)
        # Exclude the current letter so the word really changes.
        alternatives = letters.replace(tail[position], "")
        new_letter = rng.choice(alternatives)
        candidate = tail[:position] + new_letter + tail[position + 1 :]
        if candidate not in taken:
            taken.add(candidate)
            ladder.append(candidate)
    return ladder
|
| 1309 |
+
|
| 1310 |
+
|
| 1311 |
+
def _build_word_ladder_extras(
    rng: random.Random,
    length: int,
    count: int,
    disallowed: set[str],
) -> list[str]:
    """Return ``count`` distinct random words, none of which appear in ``disallowed``.

    ``disallowed`` itself is not modified; a private copy tracks the words
    that may no longer be used.
    """
    blocked = set(disallowed)
    extras: list[str] = []
    while len(extras) < count:
        candidate = _random_word(rng, length)
        if candidate not in blocked:
            blocked.add(candidate)
            extras.append(candidate)
    return extras
|
env/test_cases.py
CHANGED
|
@@ -2,9 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
-
from env.generator import DIFFICULTY_LABELS, GeneratorAgent
|
| 6 |
-
|
| 7 |
-
VISIBLE_TEST_COUNT = 0
|
| 8 |
|
| 9 |
|
| 10 |
def load_problem_bank() -> list[dict[str, Any]]:
|
|
|
|
| 2 |
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
+
from env.generator import DIFFICULTY_LABELS, GeneratorAgent, VISIBLE_TEST_COUNT
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def load_problem_bank() -> list[dict[str, Any]]:
|
models.py
CHANGED
|
@@ -18,17 +18,22 @@ except ImportError:
|
|
| 18 |
episode_id: str = ""
|
| 19 |
step_count: int = 0
|
| 20 |
|
| 21 |
-
from pydantic import Field
|
| 22 |
-
|
| 23 |
|
| 24 |
class AdaptAction(Action):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
code: str = Field(..., min_length=1, description="Python code to execute.")
|
| 26 |
|
| 27 |
|
| 28 |
class AdaptObservation(Observation):
|
|
|
|
| 29 |
problem_id: str = Field(default="", description="Current problem identifier.")
|
| 30 |
problem_type: str = Field(default="", description="Current generated problem family.")
|
| 31 |
difficulty: str = Field(default="", description="Current curriculum difficulty tier.")
|
|
|
|
|
|
|
| 32 |
problem: str = Field(default="", description="Problem statement shown to the agent.")
|
| 33 |
input_format: str = Field(default="", description="Expected stdin format.")
|
| 34 |
constraints: str = Field(default="", description="Problem constraints.")
|
|
@@ -48,14 +53,17 @@ class AdaptObservation(Observation):
|
|
| 48 |
|
| 49 |
|
| 50 |
class AdaptState(State):
|
|
|
|
| 51 |
problem_id: str = Field(default="")
|
| 52 |
problem_type: str = Field(default="")
|
| 53 |
difficulty: str = Field(default="")
|
| 54 |
generator_mode: str = Field(default="heuristic")
|
| 55 |
-
|
|
|
|
| 56 |
last_reward: float = Field(default=0.0)
|
| 57 |
last_pass_rate: float = Field(default=0.0, ge=0.0, le=1.0)
|
| 58 |
last_feedback: str = Field(default="")
|
|
|
|
| 59 |
generator_reward_signal: float = Field(default=0.0)
|
| 60 |
history: dict[str, Any] = Field(default_factory=dict)
|
| 61 |
recent_metrics: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
| 18 |
episode_id: str = ""
|
| 19 |
step_count: int = 0
|
| 20 |
|
|
|
|
|
|
|
| 21 |
|
| 22 |
class AdaptAction(Action):
|
| 23 |
+
session_id: str = Field(
|
| 24 |
+
default="",
|
| 25 |
+
description="Environment session id for server-routed calls.",
|
| 26 |
+
)
|
| 27 |
code: str = Field(..., min_length=1, description="Python code to execute.")
|
| 28 |
|
| 29 |
|
| 30 |
class AdaptObservation(Observation):
|
| 31 |
+
session_id: str = Field(default="", description="Session id for the active environment instance.")
|
| 32 |
problem_id: str = Field(default="", description="Current problem identifier.")
|
| 33 |
problem_type: str = Field(default="", description="Current generated problem family.")
|
| 34 |
difficulty: str = Field(default="", description="Current curriculum difficulty tier.")
|
| 35 |
+
attempt_number: int = Field(default=0, ge=0, description="1-indexed attempt number within the episode.")
|
| 36 |
+
max_steps: int = Field(default=3, ge=1, description="Maximum attempts allowed for the episode.")
|
| 37 |
problem: str = Field(default="", description="Problem statement shown to the agent.")
|
| 38 |
input_format: str = Field(default="", description="Expected stdin format.")
|
| 39 |
constraints: str = Field(default="", description="Problem constraints.")
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
class AdaptState(State):
|
| 56 |
+
session_id: str = Field(default="")
|
| 57 |
problem_id: str = Field(default="")
|
| 58 |
problem_type: str = Field(default="")
|
| 59 |
difficulty: str = Field(default="")
|
| 60 |
generator_mode: str = Field(default="heuristic")
|
| 61 |
+
max_steps: int = Field(default=3, ge=1)
|
| 62 |
+
generated_problem: dict[str, Any] = Field(default_factory=dict)
|
| 63 |
last_reward: float = Field(default=0.0)
|
| 64 |
last_pass_rate: float = Field(default=0.0, ge=0.0, le=1.0)
|
| 65 |
last_feedback: str = Field(default="")
|
| 66 |
+
last_execution_status: str = Field(default="ready")
|
| 67 |
generator_reward_signal: float = Field(default=0.0)
|
| 68 |
history: dict[str, Any] = Field(default_factory=dict)
|
| 69 |
recent_metrics: dict[str, Any] = Field(default_factory=dict)
|
openenv.yaml
CHANGED
|
@@ -1,35 +1,7 @@
|
|
| 1 |
spec_version: 1
|
| 2 |
-
name:
|
| 3 |
-
|
| 4 |
runtime: fastapi
|
| 5 |
app: server.app:app
|
| 6 |
port: 7860
|
| 7 |
-
description: "
|
| 8 |
-
version: "0.2.0"
|
| 9 |
-
|
| 10 |
-
observation_space:
|
| 11 |
-
type: dict
|
| 12 |
-
description: "Problem prompt, examples, visible tests, difficulty metadata, reward, pass rates, execution status, and feedback."
|
| 13 |
-
|
| 14 |
-
action_space:
|
| 15 |
-
type: dict
|
| 16 |
-
description: "AdaptAction with a Python code string submitted for stdin/stdout evaluation."
|
| 17 |
-
|
| 18 |
-
reward_range: [0.0, 1.0]
|
| 19 |
-
|
| 20 |
-
tasks:
|
| 21 |
-
- name: easy_double
|
| 22 |
-
description: "Easy arithmetic stdin/stdout problem."
|
| 23 |
-
difficulty: easy
|
| 24 |
-
- name: easy_sum_two
|
| 25 |
-
description: "Easy two-integer arithmetic problem."
|
| 26 |
-
difficulty: easy
|
| 27 |
-
- name: medium_maximum
|
| 28 |
-
description: "Medium array scanning problem."
|
| 29 |
-
difficulty: medium
|
| 30 |
-
- name: medium_count_even
|
| 31 |
-
description: "Medium counting problem over a list."
|
| 32 |
-
difficulty: medium
|
| 33 |
-
- name: hard_reverse_words
|
| 34 |
-
description: "Harder string normalization and ordering problem."
|
| 35 |
-
difficulty: hard
|
|
|
|
| 1 |
spec_version: 1
|
| 2 |
+
name: adapt-dsa-tutor
|
| 3 |
+
version: "0.3.0"
|
| 4 |
runtime: fastapi
|
| 5 |
app: server.app:app
|
| 6 |
port: 7860
|
| 7 |
+
description: "Adversarial DSA Programming Tutor - RL environment for training LLMs to solve algorithmic problems through adaptive curriculum and self-repair"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/test_env.py
CHANGED
|
@@ -7,7 +7,7 @@ ROOT = Path(__file__).resolve().parents[1]
|
|
| 7 |
if str(ROOT) not in sys.path:
|
| 8 |
sys.path.insert(0, str(ROOT))
|
| 9 |
|
| 10 |
-
from env.adapt_env import AdaptEnvironment
|
| 11 |
from env.generator import GeneratorAgent
|
| 12 |
from models import AdaptAction
|
| 13 |
|
|
@@ -23,10 +23,12 @@ def main() -> None:
|
|
| 23 |
env = AdaptEnvironment(generator=GeneratorAgent())
|
| 24 |
observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
|
| 25 |
assert observation.problem
|
|
|
|
| 26 |
assert observation.input_format
|
| 27 |
assert observation.constraints
|
| 28 |
assert observation.problem_type == "sum_even_numbers"
|
| 29 |
assert observation.execution_status == "ready"
|
|
|
|
| 30 |
assert_hidden_tests_are_not_exposed(observation.model_dump())
|
| 31 |
|
| 32 |
correct = env.step(
|
|
@@ -39,11 +41,13 @@ def main() -> None:
|
|
| 39 |
)
|
| 40 |
)
|
| 41 |
print(correct)
|
| 42 |
-
assert correct.reward
|
| 43 |
assert correct.pass_rate == 1.0
|
| 44 |
assert correct.execution_status == "completed"
|
|
|
|
| 45 |
|
| 46 |
-
|
|
|
|
| 47 |
AdaptAction(
|
| 48 |
code=(
|
| 49 |
"n=int(input())\n"
|
|
@@ -52,41 +56,64 @@ def main() -> None:
|
|
| 52 |
)
|
| 53 |
)
|
| 54 |
)
|
| 55 |
-
print(
|
| 56 |
-
assert
|
| 57 |
-
assert
|
| 58 |
-
assert
|
| 59 |
|
| 60 |
-
|
| 61 |
AdaptAction(
|
| 62 |
code=(
|
| 63 |
"n=int(input())\n"
|
| 64 |
-
"input()\n"
|
| 65 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
)
|
| 67 |
)
|
| 68 |
)
|
| 69 |
-
print(
|
| 70 |
-
assert
|
| 71 |
-
assert
|
|
|
|
|
|
|
| 72 |
|
|
|
|
| 73 |
syntax = env.step(AdaptAction(code="def broken(:\n pass"))
|
| 74 |
print(syntax)
|
| 75 |
assert syntax.reward == 0.0
|
|
|
|
| 76 |
assert syntax.execution_status == "syntax_error"
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
timeout = env.step(AdaptAction(code="while True:\n pass"))
|
| 79 |
print(timeout)
|
| 80 |
assert timeout.timeout_count > 0
|
| 81 |
assert timeout.execution_status == "timeout"
|
|
|
|
| 82 |
|
|
|
|
| 83 |
unsafe = env.step(AdaptAction(code="import os\nprint(os.listdir('.'))"))
|
| 84 |
print(unsafe)
|
| 85 |
assert unsafe.reward == 0.0
|
| 86 |
assert unsafe.execution_status == "safety_violation"
|
|
|
|
| 87 |
|
| 88 |
-
assert env.state.
|
| 89 |
-
assert env.state.history["recent_pass_rates"]
|
| 90 |
assert_hidden_tests_are_not_exposed(timeout.model_dump())
|
| 91 |
print("ADAPT OpenEnv smoke tests passed")
|
| 92 |
|
|
|
|
| 7 |
if str(ROOT) not in sys.path:
|
| 8 |
sys.path.insert(0, str(ROOT))
|
| 9 |
|
| 10 |
+
from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
|
| 11 |
from env.generator import GeneratorAgent
|
| 12 |
from models import AdaptAction
|
| 13 |
|
|
|
|
| 23 |
env = AdaptEnvironment(generator=GeneratorAgent())
|
| 24 |
observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
|
| 25 |
assert observation.problem
|
| 26 |
+
assert "Examples:" in observation.problem
|
| 27 |
assert observation.input_format
|
| 28 |
assert observation.constraints
|
| 29 |
assert observation.problem_type == "sum_even_numbers"
|
| 30 |
assert observation.execution_status == "ready"
|
| 31 |
+
assert observation.max_steps == MAX_STEPS_PER_EPISODE
|
| 32 |
assert_hidden_tests_are_not_exposed(observation.model_dump())
|
| 33 |
|
| 34 |
correct = env.step(
|
|
|
|
| 41 |
)
|
| 42 |
)
|
| 43 |
print(correct)
|
| 44 |
+
assert correct.reward == 1.0, correct.model_dump()
|
| 45 |
assert correct.pass_rate == 1.0
|
| 46 |
assert correct.execution_status == "completed"
|
| 47 |
+
assert correct.done is True
|
| 48 |
|
| 49 |
+
observation = env.reset(problem_id="running_total", difficulty="easy")
|
| 50 |
+
repair_1 = env.step(
|
| 51 |
AdaptAction(
|
| 52 |
code=(
|
| 53 |
"n=int(input())\n"
|
|
|
|
| 56 |
)
|
| 57 |
)
|
| 58 |
)
|
| 59 |
+
print(repair_1)
|
| 60 |
+
assert repair_1.done is False
|
| 61 |
+
assert repair_1.execution_status in {"wrong_answer", "runtime_error", "invalid_output_format"}
|
| 62 |
+
assert "Previous attempt status: ready" in repair_1.feedback
|
| 63 |
|
| 64 |
+
repair_2 = env.step(
|
| 65 |
AdaptAction(
|
| 66 |
code=(
|
| 67 |
"n=int(input())\n"
|
| 68 |
+
"nums=list(map(int,input().split()))\n"
|
| 69 |
+
"running=0\n"
|
| 70 |
+
"out=[]\n"
|
| 71 |
+
"for x in nums:\n"
|
| 72 |
+
" running += x\n"
|
| 73 |
+
" out.append(str(running))\n"
|
| 74 |
+
"print(' '.join(out))"
|
| 75 |
)
|
| 76 |
)
|
| 77 |
)
|
| 78 |
+
print(repair_2)
|
| 79 |
+
assert repair_2.done is True
|
| 80 |
+
assert repair_2.pass_rate == 1.0
|
| 81 |
+
assert repair_2.reward == 0.85
|
| 82 |
+
assert "Previous attempt status:" in repair_2.feedback
|
| 83 |
|
| 84 |
+
observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
|
| 85 |
syntax = env.step(AdaptAction(code="def broken(:\n pass"))
|
| 86 |
print(syntax)
|
| 87 |
assert syntax.reward == 0.0
|
| 88 |
+
assert syntax.done is False
|
| 89 |
assert syntax.execution_status == "syntax_error"
|
| 90 |
|
| 91 |
+
runtime = env.step(
|
| 92 |
+
AdaptAction(
|
| 93 |
+
code=(
|
| 94 |
+
"n=int(input())\n"
|
| 95 |
+
"nums=list(map(int,input().split()))\n"
|
| 96 |
+
"print(nums[n])"
|
| 97 |
+
)
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
print(runtime)
|
| 101 |
+
assert runtime.execution_status == "runtime_error"
|
| 102 |
+
|
| 103 |
timeout = env.step(AdaptAction(code="while True:\n pass"))
|
| 104 |
print(timeout)
|
| 105 |
assert timeout.timeout_count > 0
|
| 106 |
assert timeout.execution_status == "timeout"
|
| 107 |
+
assert timeout.done is True
|
| 108 |
|
| 109 |
+
observation = env.reset(problem_id="sum_even_numbers", difficulty="easy")
|
| 110 |
unsafe = env.step(AdaptAction(code="import os\nprint(os.listdir('.'))"))
|
| 111 |
print(unsafe)
|
| 112 |
assert unsafe.reward == 0.0
|
| 113 |
assert unsafe.execution_status == "safety_violation"
|
| 114 |
+
assert unsafe.done is False
|
| 115 |
|
| 116 |
+
assert env.state.history["attempts"]
|
|
|
|
| 117 |
assert_hidden_tests_are_not_exposed(timeout.model_dump())
|
| 118 |
print("ADAPT OpenEnv smoke tests passed")
|
| 119 |
|
server/app.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import argparse
|
|
|
|
| 4 |
from typing import Any
|
|
|
|
| 5 |
|
| 6 |
import uvicorn
|
| 7 |
-
from fastapi import Body, FastAPI, HTTPException, Request
|
| 8 |
from fastapi.responses import RedirectResponse, Response
|
| 9 |
from pydantic import BaseModel
|
| 10 |
|
|
@@ -12,11 +14,15 @@ from env.adapt_env import AdaptEnvironment
|
|
| 12 |
from env.test_cases import load_problem_bank
|
| 13 |
from models import AdaptAction, AdaptObservation, AdaptState
|
| 14 |
|
| 15 |
-
ENV_NAME = "
|
| 16 |
ENV_DESCRIPTION = (
|
| 17 |
-
"
|
| 18 |
-
"
|
| 19 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
TASKS = [
|
| 21 |
{
|
| 22 |
"name": problem["problem_id"],
|
|
@@ -26,11 +32,11 @@ TASKS = [
|
|
| 26 |
for problem in load_problem_bank()
|
| 27 |
]
|
| 28 |
|
| 29 |
-
app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version=
|
| 30 |
-
ENV = AdaptEnvironment()
|
| 31 |
|
| 32 |
|
| 33 |
class ResetRequest(BaseModel):
|
|
|
|
| 34 |
seed: int | None = None
|
| 35 |
episode_id: str | None = None
|
| 36 |
problem_id: str | None = None
|
|
@@ -41,16 +47,47 @@ def _metadata() -> dict[str, Any]:
|
|
| 41 |
return {
|
| 42 |
"name": ENV_NAME,
|
| 43 |
"description": ENV_DESCRIPTION,
|
| 44 |
-
"version":
|
| 45 |
"tasks": TASKS,
|
| 46 |
"mode": "simulation",
|
| 47 |
}
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
@app.get("/")
|
| 51 |
def root() -> dict[str, Any]:
|
|
|
|
| 52 |
payload = _metadata()
|
| 53 |
payload["status"] = "ok"
|
|
|
|
| 54 |
return payload
|
| 55 |
|
| 56 |
|
|
@@ -70,22 +107,26 @@ def favicon() -> Response:
|
|
| 70 |
|
| 71 |
|
| 72 |
@app.get("/health")
|
| 73 |
-
def health() -> dict[str,
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
@app.get("/metadata")
|
| 78 |
def metadata() -> dict[str, Any]:
|
|
|
|
| 79 |
return _metadata()
|
| 80 |
|
| 81 |
|
| 82 |
@app.get("/tasks")
|
| 83 |
def list_tasks() -> dict[str, Any]:
|
|
|
|
| 84 |
return {"tasks": TASKS}
|
| 85 |
|
| 86 |
|
| 87 |
@app.get("/schema")
|
| 88 |
def schema() -> dict[str, Any]:
|
|
|
|
| 89 |
return {
|
| 90 |
"action": AdaptAction.model_json_schema(),
|
| 91 |
"observation": AdaptObservation.model_json_schema(),
|
|
@@ -95,6 +136,7 @@ def schema() -> dict[str, Any]:
|
|
| 95 |
|
| 96 |
@app.post("/mcp")
|
| 97 |
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
|
|
|
|
| 98 |
return {
|
| 99 |
"jsonrpc": "2.0",
|
| 100 |
"id": payload.get("id"),
|
|
@@ -107,8 +149,14 @@ def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
|
|
| 107 |
|
| 108 |
@app.post("/reset")
|
| 109 |
def reset(request: ResetRequest | None = None) -> dict[str, Any]:
|
|
|
|
| 110 |
effective_request = request or ResetRequest()
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
seed=effective_request.seed,
|
| 113 |
episode_id=effective_request.episode_id,
|
| 114 |
problem_id=effective_request.problem_id,
|
|
@@ -119,6 +167,7 @@ def reset(request: ResetRequest | None = None) -> dict[str, Any]:
|
|
| 119 |
|
| 120 |
@app.post("/step")
|
| 121 |
async def step(request: Request) -> dict[str, Any]:
|
|
|
|
| 122 |
payload = await request.json()
|
| 123 |
if not isinstance(payload, dict):
|
| 124 |
raise HTTPException(status_code=422, detail="Request body must be a JSON object.")
|
|
@@ -129,24 +178,31 @@ async def step(request: Request) -> dict[str, Any]:
|
|
| 129 |
except Exception as exc:
|
| 130 |
raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
return {
|
| 134 |
"observation": observation.model_dump(),
|
| 135 |
"reward": float(observation.reward),
|
| 136 |
"done": bool(observation.done),
|
| 137 |
"info": {
|
|
|
|
| 138 |
"feedback": observation.feedback,
|
| 139 |
"pass_rate": observation.pass_rate,
|
|
|
|
| 140 |
"execution_status": observation.execution_status,
|
| 141 |
},
|
| 142 |
}
|
| 143 |
|
| 144 |
|
| 145 |
@app.get("/state")
|
| 146 |
-
def state() -> dict[str, Any]:
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
def main(host: str | None = None, port: int | None = None) -> None:
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import argparse
|
| 4 |
+
from datetime import datetime, timedelta, timezone
|
| 5 |
from typing import Any
|
| 6 |
+
from uuid import uuid4
|
| 7 |
|
| 8 |
import uvicorn
|
| 9 |
+
from fastapi import Body, FastAPI, HTTPException, Query, Request
|
| 10 |
from fastapi.responses import RedirectResponse, Response
|
| 11 |
from pydantic import BaseModel
|
| 12 |
|
|
|
|
| 14 |
from env.test_cases import load_problem_bank
|
| 15 |
from models import AdaptAction, AdaptObservation, AdaptState
|
| 16 |
|
| 17 |
+
ENV_NAME = "adapt-dsa-tutor"
|
| 18 |
ENV_DESCRIPTION = (
|
| 19 |
+
"Adversarial DSA Programming Tutor - RL environment for training LLMs to solve "
|
| 20 |
+
"algorithmic problems through adaptive curriculum and self-repair."
|
| 21 |
)
|
| 22 |
+
ENV_VERSION = "0.3.0"
|
| 23 |
+
SESSION_TTL = timedelta(minutes=30)
|
| 24 |
+
SESSIONS: dict[str, AdaptEnvironment] = {}
|
| 25 |
+
SESSION_LAST_ACCESSED: dict[str, datetime] = {}
|
| 26 |
TASKS = [
|
| 27 |
{
|
| 28 |
"name": problem["problem_id"],
|
|
|
|
| 32 |
for problem in load_problem_bank()
|
| 33 |
]
|
| 34 |
|
| 35 |
+
app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version=ENV_VERSION)
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
class ResetRequest(BaseModel):
|
| 39 |
+
session_id: str | None = None
|
| 40 |
seed: int | None = None
|
| 41 |
episode_id: str | None = None
|
| 42 |
problem_id: str | None = None
|
|
|
|
| 47 |
return {
|
| 48 |
"name": ENV_NAME,
|
| 49 |
"description": ENV_DESCRIPTION,
|
| 50 |
+
"version": ENV_VERSION,
|
| 51 |
"tasks": TASKS,
|
| 52 |
"mode": "simulation",
|
| 53 |
}
|
| 54 |
|
| 55 |
|
| 56 |
+
def _utc_now() -> datetime:
|
| 57 |
+
return datetime.now(timezone.utc)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _cleanup_sessions() -> None:
    """Evict every session whose last access is older than ``SESSION_TTL``.

    Expired ids are removed from both ``SESSIONS`` and
    ``SESSION_LAST_ACCESSED``; sessions inside the TTL are left untouched.
    """
    now = _utc_now()
    # Scan a snapshot rather than the live dict: FastAPI executes plain ``def``
    # endpoints on worker threads, so another request may add a session while
    # this scan runs, and iterating the live mapping could then raise
    # "dictionary changed size during iteration".
    for session_id, last_seen in list(SESSION_LAST_ACCESSED.items()):
        if now - last_seen > SESSION_TTL:
            SESSIONS.pop(session_id, None)
            SESSION_LAST_ACCESSED.pop(session_id, None)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _touch_session(session_id: str) -> None:
    """Record the current UTC time as the last access time of ``session_id``."""
    accessed_at = _utc_now()
    SESSION_LAST_ACCESSED[session_id] = accessed_at
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _require_session(session_id: str) -> AdaptEnvironment:
    """Return the live environment registered under ``session_id``.

    Expired sessions are purged first, and a successful lookup refreshes the
    session's last-access time.

    Raises:
        HTTPException: 404 when no session is stored under ``session_id``.
    """
    _cleanup_sessions()
    session_env = SESSIONS.get(session_id)
    if session_env is not None:
        _touch_session(session_id)
        return session_env
    raise HTTPException(status_code=404, detail=f"Unknown or expired session_id: {session_id}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
@app.get("/")
def root() -> dict[str, Any]:
    """Return the environment metadata plus a status flag and the live-session count."""
    _cleanup_sessions()
    return {**_metadata(), "status": "ok", "active_sessions": len(SESSIONS)}
|
| 92 |
|
| 93 |
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
@app.get("/health")
def health() -> dict[str, Any]:
    """Report liveness together with the number of sessions still stored."""
    _cleanup_sessions()
    session_count = len(SESSIONS)
    return {"status": "healthy", "active_sessions": session_count}
|
| 113 |
|
| 114 |
|
| 115 |
@app.get("/metadata")
def metadata() -> dict[str, Any]:
    """Purge expired sessions, then return the payload built by ``_metadata``."""
    _cleanup_sessions()
    description = _metadata()
    return description
|
| 119 |
|
| 120 |
|
| 121 |
@app.get("/tasks")
def list_tasks() -> dict[str, Any]:
    """Purge expired sessions, then return the module-level ``TASKS`` list."""
    _cleanup_sessions()
    return dict(tasks=TASKS)
|
| 125 |
|
| 126 |
|
| 127 |
@app.get("/schema")
|
| 128 |
def schema() -> dict[str, Any]:
|
| 129 |
+
_cleanup_sessions()
|
| 130 |
return {
|
| 131 |
"action": AdaptAction.model_json_schema(),
|
| 132 |
"observation": AdaptObservation.model_json_schema(),
|
|
|
|
| 136 |
|
| 137 |
@app.post("/mcp")
|
| 138 |
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
|
| 139 |
+
_cleanup_sessions()
|
| 140 |
return {
|
| 141 |
"jsonrpc": "2.0",
|
| 142 |
"id": payload.get("id"),
|
|
|
|
| 149 |
|
| 150 |
@app.post("/reset")
|
| 151 |
def reset(request: ResetRequest | None = None) -> dict[str, Any]:
|
| 152 |
+
_cleanup_sessions()
|
| 153 |
effective_request = request or ResetRequest()
|
| 154 |
+
session_id = effective_request.session_id or str(uuid4())
|
| 155 |
+
env = AdaptEnvironment(session_id=session_id)
|
| 156 |
+
SESSIONS[session_id] = env
|
| 157 |
+
_touch_session(session_id)
|
| 158 |
+
observation = env.reset(
|
| 159 |
+
session_id=session_id,
|
| 160 |
seed=effective_request.seed,
|
| 161 |
episode_id=effective_request.episode_id,
|
| 162 |
problem_id=effective_request.problem_id,
|
|
|
|
| 167 |
|
| 168 |
@app.post("/step")
|
| 169 |
async def step(request: Request) -> dict[str, Any]:
|
| 170 |
+
_cleanup_sessions()
|
| 171 |
payload = await request.json()
|
| 172 |
if not isinstance(payload, dict):
|
| 173 |
raise HTTPException(status_code=422, detail="Request body must be a JSON object.")
|
|
|
|
| 178 |
except Exception as exc:
|
| 179 |
raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc
|
| 180 |
|
| 181 |
+
if not effective_action.session_id:
|
| 182 |
+
raise HTTPException(status_code=422, detail="`session_id` is required in the /step request body.")
|
| 183 |
+
|
| 184 |
+
env = _require_session(effective_action.session_id)
|
| 185 |
+
observation = env.step(effective_action)
|
| 186 |
return {
|
| 187 |
"observation": observation.model_dump(),
|
| 188 |
"reward": float(observation.reward),
|
| 189 |
"done": bool(observation.done),
|
| 190 |
"info": {
|
| 191 |
+
"session_id": observation.session_id,
|
| 192 |
"feedback": observation.feedback,
|
| 193 |
"pass_rate": observation.pass_rate,
|
| 194 |
+
"visible_pass_rate": observation.visible_pass_rate,
|
| 195 |
"execution_status": observation.execution_status,
|
| 196 |
},
|
| 197 |
}
|
| 198 |
|
| 199 |
|
| 200 |
@app.get("/state")
def state(session_id: str = Query(..., description="Session id returned from /reset.")) -> dict[str, Any]:
    """Return the serialized state of the session identified by ``session_id``.

    The lookup goes through ``_require_session``. If the environment has no
    problem loaded yet, it is reset before its state is dumped.
    """
    session_env = _require_session(session_id)
    has_problem = bool(session_env.problem)
    if not has_problem:
        session_env.reset(session_id=session_id)
    return session_env.state.model_dump()
|
| 206 |
|
| 207 |
|
| 208 |
def main(host: str | None = None, port: int | None = None) -> None:
|
training/plot_results.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _parse_cell(value: Any) -> Any:
    """Convert one raw CSV cell to ``int`` or ``float`` when it holds numeric text."""
    if not isinstance(value, str):
        # DictReader yields None for cells missing from a short row and a list
        # for surplus cells; neither is text, so pass them through unchanged.
        return value
    text = value.strip()
    if not text:
        return text
    try:
        return int(text)
    except ValueError:
        pass
    try:
        # Covers decimals and exponent notation such as "1e-05", which Python
        # emits for small floats and which contains no "." to key on.
        return float(text)
    except ValueError:
        return text


def read_rows(csv_path: Path) -> list[dict[str, Any]]:
    """Load a CSV file into a list of dicts with numeric cells converted.

    Args:
        csv_path: Path of the UTF-8 CSV file; its first line is the header.

    Returns:
        One dict per data row, keyed by column name. Text cells are stripped;
        cells that parse as integers become ``int``, other numeric text becomes
        ``float``, and everything else (including blank cells) stays a string.
    """
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        return [
            {key: _parse_cell(value) for key, value in row.items()}
            for row in csv.DictReader(handle)
        ]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def rolling_mean(values: list[float], window: int) -> list[float]:
    """Return the trailing moving average of ``values``.

    Element ``i`` of the result is the mean of the last ``window`` values
    ending at index ``i``; near the start, where fewer than ``window`` values
    exist, the mean covers whatever is available.
    """

    def _mean_ending_at(end: int) -> float:
        tail = values[max(0, end - window) : end]
        return sum(tail) / len(tail)

    return [_mean_ending_at(end) for end in range(1, len(values) + 1)]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def plot_reward_curve(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Plot per-episode training reward and its moving average.

    Only rows whose ``phase`` is ``"train"`` are drawn. The figure is written
    to ``output_dir / "reward_curve.png"``.
    """
    import matplotlib.pyplot as plt

    training = [row for row in rows if row.get("phase") == "train"]
    x_steps = [int(row["step"]) for row in training]
    raw_rewards = [float(row["episode_reward"]) for row in training]
    smoothed = rolling_mean(raw_rewards, window=20)

    fig, ax = plt.subplots(figsize=(10, 5))
    # Faint raw series underneath, bold smoothed series on top.
    ax.plot(x_steps, raw_rewards, alpha=0.25, label="Episode reward")
    ax.plot(x_steps, smoothed, linewidth=2, label="20-step moving average")
    ax.set_xlabel("Training step")
    ax.set_ylabel("Reward")
    ax.set_title("ADAPT Training Reward Curve")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_dir / "reward_curve.png", dpi=200)
    plt.close(fig)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def plot_pass_rate_by_difficulty(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Plot the smoothed training pass rate for each difficulty tier.

    Train rows are grouped by ``difficulty_tier``. One line is drawn for each
    of ``easy``, ``medium`` and ``hard`` that has data, smoothed with a
    10-point moving average. The figure is written to
    ``output_dir / "pass_rate_by_difficulty.png"``.
    """
    import matplotlib.pyplot as plt

    steps_by_tier: dict[str, list[int]] = defaultdict(list)
    rates_by_tier: dict[str, list[float]] = defaultdict(list)
    for row in rows:
        if row.get("phase") != "train":
            continue
        tier = str(row["difficulty_tier"])
        steps_by_tier[tier].append(int(row["step"]))
        rates_by_tier[tier].append(float(row["pass_rate"]))

    fig, ax = plt.subplots(figsize=(10, 5))
    for tier in ("easy", "medium", "hard"):
        if tier not in steps_by_tier:
            continue
        ax.plot(
            steps_by_tier[tier],
            rolling_mean(rates_by_tier[tier], window=10),
            linewidth=2,
            label=tier.title(),
        )

    ax.set_xlabel("Training step")
    ax.set_ylabel("Pass rate")
    ax.set_title("Pass Rate by Difficulty Tier")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_dir / "pass_rate_by_difficulty.png", dpi=200)
    plt.close(fig)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def plot_family_productivity(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Plot the ``family_productivity__*`` columns over training steps.

    Only rows whose ``phase`` is ``"train"`` are used. The eight columns with
    the highest value in the last train row are drawn, and the figure is
    written to ``output_dir / "family_productivity.png"``. Nothing is drawn or
    written when there are no train rows or no productivity columns.
    """
    train_rows = [row for row in rows if row.get("phase") == "train"]
    if not train_rows:
        # Without this guard, indexing train_rows[0] below raises IndexError.
        return

    productivity_columns = [key for key in train_rows[0] if str(key).startswith("family_productivity__")]
    if not productivity_columns:
        return

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    def _as_float(cell: Any) -> float:
        """Read a cell as a float, treating a blank or missing cell as 0.0."""
        if cell is None or cell == "":
            return 0.0
        return float(cell)

    ranked_columns = sorted(
        productivity_columns,
        key=lambda column: _as_float(train_rows[-1].get(column)),
        reverse=True,
    )[:8]

    plt.figure(figsize=(11, 6))
    steps = [int(row["step"]) for row in train_rows]
    for column in ranked_columns:
        # The column name is "family_productivity__<family>"; label by family.
        family = column.split("__", 1)[1]
        values = [_as_float(row.get(column)) for row in train_rows]
        plt.plot(steps, values, linewidth=2, label=family)

    plt.xlabel("Training step")
    plt.ylabel("Family productivity EMA")
    plt.title("Reward-Aware Family Productivity Over Training")
    plt.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    plt.savefig(output_dir / "family_productivity.png", dpi=200)
    plt.close()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: read the reward CSV and write all three plots.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Raises:
        RuntimeError: If the CSV file contains no data rows.
    """
    parser = argparse.ArgumentParser(description="Plot ADAPT reward and curriculum artifacts from reward_curve.csv.")
    parser.add_argument("csv_path", help="Path to reward_curve.csv")
    parser.add_argument("--output-dir", default=None, help="Directory for PNG outputs. Defaults to the CSV directory.")
    options = parser.parse_args(argv)

    source = Path(options.csv_path)
    target = source.parent if not options.output_dir else Path(options.output_dir)
    target.mkdir(parents=True, exist_ok=True)

    rows = read_rows(source)
    if not rows:
        raise RuntimeError(f"No rows found in {source}")

    for render in (plot_reward_curve, plot_pass_rate_by_difficulty, plot_family_productivity):
        render(rows, target)
    print(f"Saved plots to {target}")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
main()
|
training/train_grpo.py
CHANGED
|
@@ -1,14 +1,25 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import argparse
|
|
|
|
| 4 |
import json
|
|
|
|
| 5 |
from dataclasses import dataclass, field
|
|
|
|
| 6 |
from typing import Any
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
from models import AdaptAction
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def extract_code(completion: str) -> str:
|
| 14 |
text = completion.strip()
|
|
@@ -19,21 +30,45 @@ def extract_code(completion: str) -> str:
|
|
| 19 |
return text
|
| 20 |
|
| 21 |
|
| 22 |
-
def
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"problem_id": problem["problem_id"],
|
|
|
|
| 25 |
"difficulty": problem["difficulty_label"],
|
| 26 |
-
"
|
|
|
|
|
|
|
| 27 |
"input_format": problem["input_format"],
|
| 28 |
"constraints": problem["constraints"],
|
|
|
|
| 29 |
}
|
| 30 |
-
return (
|
| 31 |
-
"You are the Solver Agent for ADAPT.\n"
|
| 32 |
-
"Read the generated DSA task and reply with only runnable Python code.\n"
|
| 33 |
-
"The program must read from stdin and print to stdout.\n"
|
| 34 |
-
"No markdown, no explanation.\n\n"
|
| 35 |
-
f"{json.dumps(public_problem, indent=2)}"
|
| 36 |
-
)
|
| 37 |
|
| 38 |
|
| 39 |
@dataclass
|
|
@@ -42,29 +77,37 @@ class CurriculumManager:
|
|
| 42 |
current_idx: int = 0
|
| 43 |
success_history: list[float] = field(default_factory=list)
|
| 44 |
window_size: int = 10
|
|
|
|
|
|
|
| 45 |
|
| 46 |
def current_difficulty(self) -> str:
|
| 47 |
return self.difficulties[self.current_idx]
|
| 48 |
|
| 49 |
-
def
|
| 50 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 51 |
if len(self.success_history) > self.window_size:
|
| 52 |
self.success_history.pop(0)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
moving_average = sum(self.success_history) / len(self.success_history)
|
| 55 |
-
if moving_average >
|
| 56 |
self.current_idx += 1
|
| 57 |
self.success_history.clear()
|
| 58 |
print(
|
| 59 |
f"[curriculum] promoted to {self.current_difficulty()} "
|
| 60 |
-
f"(
|
| 61 |
)
|
| 62 |
-
elif moving_average <
|
| 63 |
self.current_idx -= 1
|
| 64 |
self.success_history.clear()
|
| 65 |
print(
|
| 66 |
-
f"[curriculum]
|
| 67 |
-
f"(
|
| 68 |
)
|
| 69 |
|
| 70 |
|
|
@@ -72,6 +115,7 @@ class CurriculumManager:
|
|
| 72 |
class GeneratorController:
|
| 73 |
mode: str = "heuristic"
|
| 74 |
deterministic: bool = True
|
|
|
|
| 75 |
generator: GeneratorAgent = field(init=False)
|
| 76 |
history: dict[str, Any] = field(
|
| 77 |
default_factory=lambda: {
|
|
@@ -83,35 +127,80 @@ class GeneratorController:
|
|
| 83 |
}
|
| 84 |
)
|
| 85 |
prompt_registry: dict[str, dict[str, Any]] = field(default_factory=dict)
|
|
|
|
| 86 |
|
| 87 |
def __post_init__(self) -> None:
|
| 88 |
self.generator = GeneratorAgent(deterministic=self.deterministic)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
def create_rollout_problem(self, difficulty: str) -> tuple[str, dict[str, Any]]:
|
| 91 |
-
problem = self.
|
| 92 |
-
prompt =
|
| 93 |
self.prompt_registry[prompt] = problem
|
| 94 |
return prompt, problem
|
| 95 |
|
| 96 |
def resolve_prompt(self, prompt: str) -> dict[str, Any]:
|
| 97 |
if prompt not in self.prompt_registry:
|
| 98 |
raise KeyError("Prompt was not registered with the generator controller.")
|
| 99 |
-
return self.prompt_registry
|
| 100 |
-
|
| 101 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
self.history["recent_pass_rates"].append(round(float(pass_rate), 4))
|
| 103 |
self.history["problem_types"].append(problem.get("problem_type", ""))
|
| 104 |
self.history["problem_signatures"].append(problem.get("problem_id", ""))
|
| 105 |
-
|
| 106 |
-
self.history["generator_rewards"].append(round(float(generator_reward_signal), 4))
|
| 107 |
-
else:
|
| 108 |
-
self.history["generator_rewards"].append(0.0)
|
| 109 |
self.history["episode_index"] = int(self.history.get("episode_index", 0)) + 1
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
|
| 112 |
values = self.history[key]
|
| 113 |
-
if len(values) >
|
| 114 |
-
del values[:-
|
| 115 |
|
| 116 |
def stats_snapshot(self) -> dict[str, Any]:
|
| 117 |
return {
|
|
@@ -120,6 +209,13 @@ class GeneratorController:
|
|
| 120 |
"recent_pass_rates": list(self.history["recent_pass_rates"][-5:]),
|
| 121 |
"recent_problem_types": list(self.history["problem_types"][-5:]),
|
| 122 |
"recent_generator_rewards": list(self.history["generator_rewards"][-5:]),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
|
|
@@ -138,50 +234,273 @@ class GeneratorRolloutDataset:
|
|
| 138 |
return {"prompt": prompt}
|
| 139 |
|
| 140 |
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
def reward_func(prompts, completions, **kwargs) -> list[float]:
|
| 143 |
del kwargs
|
| 144 |
-
env = AdaptEnvironment(generator=controller.generator, generator_mode=controller.mode)
|
| 145 |
rewards: list[float] = []
|
| 146 |
-
pass_rates: list[float] = []
|
| 147 |
|
| 148 |
for prompt, completion in zip(prompts, completions):
|
| 149 |
problem = controller.resolve_prompt(prompt)
|
|
|
|
| 150 |
env.reset(
|
| 151 |
difficulty=problem["difficulty_label"],
|
| 152 |
generated_problem=problem,
|
| 153 |
generator_mode=controller.mode,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
)
|
| 155 |
-
observation = env.step(AdaptAction(code=extract_code(completion)))
|
| 156 |
rewards.append(float(observation.reward))
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
json.dumps(
|
| 162 |
-
{
|
| 163 |
-
"problem_id": problem["problem_id"],
|
| 164 |
-
"problem_type": problem["problem_type"],
|
| 165 |
-
"difficulty": problem["difficulty_label"],
|
| 166 |
-
"solver_reward": observation.reward,
|
| 167 |
-
"pass_rate": observation.pass_rate,
|
| 168 |
-
"generator_reward": observation.generator_reward_signal,
|
| 169 |
-
"status": observation.execution_status,
|
| 170 |
-
}
|
| 171 |
-
),
|
| 172 |
)
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
return rewards
|
| 179 |
|
| 180 |
return reward_func
|
| 181 |
|
| 182 |
|
| 183 |
-
def
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
def run_training(args: argparse.Namespace) -> None:
|
|
@@ -193,6 +512,9 @@ def run_training(args: argparse.Namespace) -> None:
|
|
| 193 |
"Training dependencies are missing. Install `trl` and `unsloth` before running GRPO training."
|
| 194 |
) from exc
|
| 195 |
|
|
|
|
|
|
|
|
|
|
| 196 |
PatchFastRL("GRPO", FastLanguageModel)
|
| 197 |
|
| 198 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
|
@@ -200,6 +522,8 @@ def run_training(args: argparse.Namespace) -> None:
|
|
| 200 |
max_seq_length=args.max_seq_length,
|
| 201 |
load_in_4bit=not args.disable_4bit,
|
| 202 |
)
|
|
|
|
|
|
|
| 203 |
|
| 204 |
model = FastLanguageModel.get_peft_model(
|
| 205 |
model,
|
|
@@ -214,6 +538,31 @@ def run_training(args: argparse.Namespace) -> None:
|
|
| 214 |
mode="reward_aware" if args.generator_mode == "reward_aware" else "heuristic",
|
| 215 |
deterministic=not args.non_deterministic_generator,
|
| 216 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
training_args = GRPOConfig(
|
| 218 |
output_dir=args.output_dir,
|
| 219 |
learning_rate=args.learning_rate,
|
|
@@ -225,41 +574,68 @@ def run_training(args: argparse.Namespace) -> None:
|
|
| 225 |
max_steps=args.max_steps,
|
| 226 |
logging_steps=1,
|
| 227 |
bf16=args.bf16,
|
|
|
|
| 228 |
)
|
| 229 |
|
| 230 |
trainer = GRPOTrainer(
|
| 231 |
model=model,
|
| 232 |
-
reward_funcs=[build_reward_func(curriculum, controller)],
|
| 233 |
args=training_args,
|
| 234 |
train_dataset=build_dataset(args.dataset_size, controller, curriculum),
|
| 235 |
)
|
| 236 |
trainer.train()
|
|
|
|
| 237 |
model.save_pretrained(args.output_dir)
|
| 238 |
tokenizer.save_pretrained(args.output_dir)
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
def build_parser() -> argparse.ArgumentParser:
|
| 242 |
parser = argparse.ArgumentParser(description="GRPO training entrypoint for the ADAPT DSA environment.")
|
| 243 |
parser.add_argument("--model-name", default="unsloth/Llama-3.2-3B-Instruct")
|
| 244 |
-
parser.add_argument("--output-dir", default="
|
| 245 |
parser.add_argument("--dataset-size", type=int, default=200)
|
| 246 |
parser.add_argument("--max-steps", type=int, default=250)
|
| 247 |
parser.add_argument("--batch-size", type=int, default=1)
|
| 248 |
parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
|
| 249 |
parser.add_argument("--num-generations", type=int, default=8)
|
| 250 |
parser.add_argument("--max-seq-length", type=int, default=2048)
|
| 251 |
-
parser.add_argument("--max-prompt-length", type=int, default=
|
| 252 |
parser.add_argument("--max-completion-length", type=int, default=512)
|
| 253 |
parser.add_argument("--learning-rate", type=float, default=5e-6)
|
| 254 |
parser.add_argument("--lora-rank", type=int, default=16)
|
| 255 |
parser.add_argument("--lora-alpha", type=int, default=16)
|
| 256 |
parser.add_argument("--disable-4bit", action="store_true")
|
| 257 |
parser.add_argument("--bf16", action="store_true")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
parser.add_argument(
|
| 259 |
"--generator-mode",
|
| 260 |
choices=["heuristic", "reward_aware"],
|
| 261 |
-
default="
|
| 262 |
-
help="Use heuristic generation
|
| 263 |
)
|
| 264 |
parser.add_argument(
|
| 265 |
"--non-deterministic-generator",
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import argparse
|
| 4 |
+
import csv
|
| 5 |
import json
|
| 6 |
+
import math
|
| 7 |
from dataclasses import dataclass, field
|
| 8 |
+
from pathlib import Path
|
| 9 |
from typing import Any
|
| 10 |
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
|
| 14 |
+
from env.generator import DIFFICULTY_LABELS, GeneratorAgent
|
| 15 |
from models import AdaptAction
|
| 16 |
|
| 17 |
+
SYSTEM_PROMPT = """You are the Solver Agent for ADAPT.
|
| 18 |
+
Write only runnable Python code.
|
| 19 |
+
The program must read from stdin and print to stdout.
|
| 20 |
+
If feedback is present, repair your previous solution instead of starting from scratch.
|
| 21 |
+
Do not include markdown fences or explanations."""
|
| 22 |
+
|
| 23 |
|
| 24 |
def extract_code(completion: str) -> str:
|
| 25 |
text = completion.strip()
|
|
|
|
| 30 |
return text
|
| 31 |
|
| 32 |
|
| 33 |
+
def format_examples(problem: dict[str, Any]) -> str:
    """Return the problem statement followed by its visible test cases.

    Args:
        problem: Problem dict with a ``"problem"`` statement and an optional
            ``"test_cases"`` list. Cases flagged ``is_visible`` are rendered.

    Returns:
        The bare statement when no test case is visible; otherwise the
        statement, an ``Examples:`` heading, and one ``Input:`` /
        ``Expected Output:`` pair per visible case.
    """
    statement = problem["problem"]
    visible_cases = [case for case in problem.get("test_cases", []) if case.get("is_visible", False)]
    if not visible_cases:
        return statement

    chunks = []
    for case in visible_cases:
        stdin_text = str(case["input"])
        # The template places "Expected Output:" directly after the input, so
        # a non-empty input lacking a trailing newline would be fused onto the
        # label. Terminate it here; inputs already ending in "\n" are unchanged.
        if stdin_text and not stdin_text.endswith("\n"):
            stdin_text += "\n"
        chunks.append(f"Input:\n{stdin_text}Expected Output:\n{case['output']}\n")
    return f"{statement}\n\nExamples:\n" + "\n".join(chunks).rstrip()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def build_solver_prompt(payload: dict[str, Any]) -> str:
    """Render the full solver prompt for one attempt.

    The prompt is ``SYSTEM_PROMPT`` followed by the problem id, family,
    difficulty, attempt counter, statement, input format, constraints, and
    feedback. A missing or empty ``feedback`` falls back to a fixed
    "no previous attempt" line; ``attempt_number`` defaults to 0 and
    ``max_steps`` to ``MAX_STEPS_PER_EPISODE``.
    """
    feedback = payload.get("feedback")
    if not feedback:
        feedback = "No previous attempt yet."

    sections = (
        f"{SYSTEM_PROMPT}\n\n",
        f"Problem ID: {payload['problem_id']}\n",
        f"Problem Family: {payload['problem_type']}\n",
        f"Difficulty: {payload['difficulty']}\n",
        f"Attempt: {payload.get('attempt_number', 0)}/{payload.get('max_steps', MAX_STEPS_PER_EPISODE)}\n\n",
        f"Problem:\n{payload['problem']}\n\n",
        f"Input Format:\n{payload['input_format']}\n\n",
        f"Constraints:\n{payload['constraints']}\n\n",
        f"Feedback:\n{feedback}\n",
    )
    return "".join(sections)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def build_prompt_from_problem(problem: dict[str, Any]) -> str:
    """Build the first-attempt solver prompt for a generated problem.

    The attempt number is fixed at 0, the step budget is
    ``MAX_STEPS_PER_EPISODE``, and the statement is expanded with its visible
    examples via ``format_examples``.
    """
    first_attempt = dict(
        problem_id=problem["problem_id"],
        problem_type=problem["problem_type"],
        difficulty=problem["difficulty_label"],
        attempt_number=0,
        max_steps=MAX_STEPS_PER_EPISODE,
        problem=format_examples(problem),
        input_format=problem["input_format"],
        constraints=problem["constraints"],
        feedback="No previous attempt yet. Solve the problem directly from the examples and constraints.",
    )
    return build_solver_prompt(first_attempt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
@dataclass
|
|
|
|
| 77 |
current_idx: int = 0
|
| 78 |
success_history: list[float] = field(default_factory=list)
|
| 79 |
window_size: int = 10
|
| 80 |
+
promote_threshold: float = 0.70
|
| 81 |
+
demote_threshold: float = 0.30
|
| 82 |
|
| 83 |
def current_difficulty(self) -> str:
    """Return the difficulty label at the curriculum's current index."""
    tiers = self.difficulties
    return tiers[self.current_idx]
|
| 85 |
|
| 86 |
+
def current_level(self) -> int:
    """Return the current curriculum position as a 1-based level number."""
    return 1 + self.current_idx
|
| 88 |
+
|
| 89 |
+
def update(self, episode_pass_rate: float) -> None:
    """Record one episode's pass rate and move the curriculum if warranted.

    The pass rate is appended to a sliding window of at most ``window_size``
    entries. Nothing else happens until the window is full. Once it is, the
    window average is compared against the promote/demote thresholds; on a
    level change the window is cleared and a message is printed.

    Args:
        episode_pass_rate: Pass rate of the episode that just finished.
    """
    window = self.success_history
    window.append(float(episode_pass_rate))
    if len(window) > self.window_size:
        window.pop(0)

    # Wait for a full window before judging the solver.
    if len(window) < self.window_size:
        return

    moving_average = sum(window) / len(window)

    if moving_average >= self.promote_threshold and self.current_idx < len(self.difficulties) - 1:
        step, verb = 1, "promoted"
    elif moving_average <= self.demote_threshold and self.current_idx > 0:
        step, verb = -1, "demoted"
    else:
        return

    self.current_idx += step
    # Start a fresh window so the new level is judged on its own episodes.
    window.clear()
    print(f"[curriculum] {verb} to {self.current_difficulty()} (moving_pass_rate={moving_average:.2f})")
|
| 112 |
|
| 113 |
|
|
|
|
| 115 |
class GeneratorController:
|
| 116 |
mode: str = "heuristic"
|
| 117 |
deterministic: bool = True
|
| 118 |
+
temperature: float = 0.5
|
| 119 |
generator: GeneratorAgent = field(init=False)
|
| 120 |
history: dict[str, Any] = field(
|
| 121 |
default_factory=lambda: {
|
|
|
|
| 127 |
}
|
| 128 |
)
|
| 129 |
prompt_registry: dict[str, dict[str, Any]] = field(default_factory=dict)
|
| 130 |
+
family_productivity: dict[str, float] = field(default_factory=dict)
|
| 131 |
|
| 132 |
def __post_init__(self) -> None:
    """Create the generator agent and seed per-family productivity scores.

    If no ``family_productivity`` mapping was supplied, every problem type
    found in the generator's templates starts at ``0.0``.
    """
    self.generator = GeneratorAgent(deterministic=self.deterministic)
    if self.family_productivity:
        return

    seeded: dict[str, float] = {}
    for template in self.generator.templates:
        seeded[template.problem_type] = 0.0
    self.family_productivity = seeded
|
| 138 |
+
|
| 139 |
+
@property
def family_names(self) -> list[str]:
    """Return the tracked problem-family names in sorted order."""
    names = list(self.family_productivity)
    names.sort()
    return names
|
| 142 |
+
|
| 143 |
+
def sample_problem(self, difficulty: str) -> dict[str, Any]:
    """Ask the generator for one problem at the requested difficulty.

    Args:
        difficulty: Difficulty label forwarded to the generator.

    Returns:
        The problem mapping produced by ``self.generator.generate_problem``,
        which also receives the controller history and the family weights
        computed for this difficulty.
    """
    weights = self.family_weights_for_difficulty(difficulty)
    return self.generator.generate_problem(
        difficulty_level=difficulty,
        history=self.history,
        family_weights=weights,
    )
|
| 151 |
|
| 152 |
def create_rollout_problem(self, difficulty: str) -> tuple[str, dict[str, Any]]:
    """Sample a problem, render its prompt, and register the pair.

    The prompt text is used as the registry key so the problem can later be
    looked up again through ``resolve_prompt``.

    Args:
        difficulty: Difficulty label of the problem to sample.

    Returns:
        A ``(prompt, problem)`` tuple.
    """
    rollout_problem = self.sample_problem(difficulty)
    rollout_prompt = build_prompt_from_problem(rollout_problem)
    self.prompt_registry[rollout_prompt] = rollout_problem
    rollout = (rollout_prompt, rollout_problem)
    return rollout
|
| 157 |
|
| 158 |
def resolve_prompt(self, prompt: str) -> dict[str, Any]:
    """Return the problem registered for ``prompt`` and drop its entry.

    Args:
        prompt: Prompt text used as the registry key.

    Returns:
        The problem mapping stored for the prompt.

    Raises:
        KeyError: If the prompt is not in the registry. Because the entry is
            removed on success, this includes a second lookup of the same
            prompt.
    """
    not_found = object()
    problem = self.prompt_registry.pop(prompt, not_found)
    if problem is not_found:
        raise KeyError("Prompt was not registered with the generator controller.")
    return problem
|
| 162 |
+
|
| 163 |
+
def family_weights_for_difficulty(self, difficulty: str) -> dict[str, float] | None:
    """Compute unnormalised sampling weights for the families of a difficulty.

    Only active in ``"reward_aware"`` mode. Each eligible family's
    productivity is divided by ``self.temperature`` and exponentiated after
    subtracting the largest logit, so the top family gets weight ``1.0``.

    Args:
        difficulty: Difficulty label to select templates for.

    Returns:
        A ``{family: weight}`` mapping, or ``None`` when the controller is
        not in reward-aware mode or no template matches the difficulty.
    """
    if self.mode != "reward_aware":
        return None

    eligible: list[str] = []
    for template in self.generator.templates:
        if DIFFICULTY_LABELS[template.difficulty_tier] == difficulty:
            eligible.append(template.problem_type)
    if not eligible:
        return None

    logits = [self.family_productivity.get(name, 0.0) / self.temperature for name in eligible]
    # Shift by the maximum so every exponent is <= 0.
    peak = max(logits)
    return dict(zip(eligible, (math.exp(logit - peak) for logit in logits)))
|
| 179 |
+
|
| 180 |
+
def update(
    self,
    problem: dict[str, Any],
    pass_rate: float,
    generator_reward_signal: float,
    *,
    update_productivity: bool = True,
) -> None:
    """Record the outcome of one episode in the controller history.

    Appends the rounded pass rate, problem type, problem id and generator
    reward to their history lists, bumps ``episode_index``, optionally folds
    the generator reward into the family's productivity score, and finally
    trims every history list to its most recent 100 entries.

    Args:
        problem: Problem mapping; ``problem_type`` and ``problem_id`` are
            read with an empty-string default.
        pass_rate: Solver pass rate for the episode.
        generator_reward_signal: Reward signal attributed to the generator.
        update_productivity: When false, the per-family productivity score
            is left untouched even in ``"reward_aware"`` mode.
    """
    log = self.history
    log["recent_pass_rates"].append(round(float(pass_rate), 4))
    family = problem.get("problem_type", "")
    log["problem_types"].append(family)
    log["problem_signatures"].append(problem.get("problem_id", ""))
    log["generator_rewards"].append(round(float(generator_reward_signal), 4))
    log["episode_index"] = int(log.get("episode_index", 0)) + 1

    if update_productivity and self.mode == "reward_aware":
        # Exponential moving average: 90% old score, 10% new reward.
        previous = float(self.family_productivity.get(family, 0.0))
        blended = 0.9 * previous + 0.1 * float(generator_reward_signal)
        self.family_productivity[family] = round(blended, 6)

    for key in ("recent_pass_rates", "problem_types", "problem_signatures", "generator_rewards"):
        series = log[key]
        surplus = len(series) - 100
        if surplus > 0:
            del series[:surplus]
|
| 204 |
|
| 205 |
def stats_snapshot(self) -> dict[str, Any]:
|
| 206 |
return {
|
|
|
|
| 209 |
"recent_pass_rates": list(self.history["recent_pass_rates"][-5:]),
|
| 210 |
"recent_problem_types": list(self.history["problem_types"][-5:]),
|
| 211 |
"recent_generator_rewards": list(self.history["generator_rewards"][-5:]),
|
| 212 |
+
"family_productivity": self.productivity_snapshot(),
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
def productivity_snapshot(self) -> dict[str, float]:
    """Return a copy of the family productivity scores, sorted by family.

    Values are converted to ``float`` and rounded to six decimal places.
    """
    snapshot: dict[str, float] = {}
    for family in sorted(self.family_productivity):
        snapshot[family] = round(float(self.family_productivity[family]), 6)
    return snapshot
|
| 220 |
|
| 221 |
|
|
|
|
| 234 |
return {"prompt": prompt}
|
| 235 |
|
| 236 |
|
| 237 |
+
@dataclass
class TrainingLogger:
    """Collect per-episode metrics, mirror them to Weights & Biases, export CSV.

    Every ``log_event`` call appends one row to ``rows`` and, when a W&B run
    is active, logs the same row to it. ``write_csv`` dumps all rows to
    ``<output_dir>/reward_curve.csv``.

    W&B is best-effort: if the import or ``wandb.init`` fails, a notice is
    printed and logging continues locally.
    """

    # Directory for CSV artifacts; created on init. Strings are accepted too.
    output_dir: Path
    # Families that get a ``family_productivity__<name>`` column in each row.
    family_names: list[str]
    use_wandb: bool = True
    wandb_project: str = "adapt-dsa-tutor"
    wandb_run_name: str | None = None
    # Accumulated event rows, in logging order.
    rows: list[dict[str, Any]] = field(default_factory=list)
    # Index of the next event; incremented after each log_event call.
    global_step: int = 0
    _wandb_run: Any = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        """Create the output directory and, if enabled, start a W&B run."""
        # Callers may pass a plain string (e.g. straight from argparse).
        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        if not self.use_wandb:
            return
        try:
            import wandb

            self._wandb_run = wandb.init(
                project=self.wandb_project,
                name=self.wandb_run_name,
                config={"family_names": self.family_names},
                reinit=True,
            )
        except Exception as exc:
            # Deliberately broad: a W&B problem must not abort training, but
            # say why remote logging is off instead of failing silently.
            print(f"[logger] wandb logging disabled: {exc!r}")
            self._wandb_run = None

    def log_event(
        self,
        *,
        phase: str,
        episode_reward: float,
        pass_rate: float,
        visible_pass_rate: float,
        difficulty_tier: str,
        problem_family: str,
        curriculum_level: int,
        execution_status: str,
        attempt_number: int,
        family_productivity: dict[str, float],
        extra: dict[str, Any] | None = None,
    ) -> None:
        """Record one event row and forward it to W&B when a run is active.

        Args:
            phase: Label of the phase that produced the event.
            episode_reward: Reward of the episode (stored rounded to 4 dp).
            pass_rate: Pass rate (stored rounded to 4 dp).
            visible_pass_rate: Visible-test pass rate (stored rounded to 4 dp).
            difficulty_tier: Difficulty label of the problem.
            problem_family: Problem family of the problem.
            curriculum_level: Curriculum level at logging time.
            execution_status: Execution status reported for the attempt.
            attempt_number: Attempt number, stored as ``int``.
            family_productivity: Productivity per family; families listed in
                ``family_names`` but absent here are logged as ``0.0``.
            extra: Optional additional columns merged into the row last, so
                they override built-in columns of the same name.
        """
        row: dict[str, Any] = {
            "step": self.global_step,
            "phase": phase,
            "episode_reward": round(float(episode_reward), 4),
            "pass_rate": round(float(pass_rate), 4),
            "visible_pass_rate": round(float(visible_pass_rate), 4),
            "difficulty_tier": difficulty_tier,
            "problem_family": problem_family,
            "curriculum_level": curriculum_level,
            "execution_status": execution_status,
            "attempt_number": int(attempt_number),
        }
        for family in self.family_names:
            row[f"family_productivity__{family}"] = round(float(family_productivity.get(family, 0.0)), 6)
        if extra:
            row.update(extra)
        self.rows.append(row)
        if self._wandb_run is not None:
            self._wandb_run.log(row, step=self.global_step)
        self.global_step += 1

    def write_csv(self) -> Path:
        """Write all rows to ``reward_curve.csv`` and return its path.

        The header is the union of all row keys in first-seen order; rows
        lacking a column get an empty cell (``csv.DictWriter`` default).
        """
        output_path = self.output_dir / "reward_curve.csv"
        # dict.fromkeys keeps first-seen order while deduplicating in O(1).
        fieldnames = list(dict.fromkeys(key for row in self.rows for key in row))
        with output_path.open("w", newline="", encoding="utf-8") as handle:
            writer = csv.DictWriter(handle, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.rows)
        return output_path

    def close(self) -> None:
        """Finish the W&B run, if one was started."""
        if self._wandb_run is not None:
            self._wandb_run.finish()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
    """Create the rollout dataset wired to the given controller and curriculum.

    Args:
        size: Dataset size forwarded to ``GeneratorRolloutDataset``.
        controller: Generator controller forwarded to the dataset.
        curriculum: Curriculum manager forwarded to the dataset.

    Returns:
        The constructed ``GeneratorRolloutDataset``.
    """
    dataset_kwargs = {"size": size, "controller": controller, "curriculum": curriculum}
    return GeneratorRolloutDataset(**dataset_kwargs)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def build_reward_func(
    curriculum: CurriculumManager,
    controller: GeneratorController,
    logger: TrainingLogger,
):
    """Build the reward callback handed to the GRPO trainer.

    The returned function scores each completion by running it for one step
    in a fresh ``AdaptEnvironment`` on the problem registered for its prompt.
    As side effects it updates the generator controller and the curriculum
    and logs a ``"train"`` event per completion.

    Args:
        curriculum: Curriculum manager updated with every pass rate.
        controller: Generator controller that owns the prompt registry.
        logger: Training logger receiving one event per completion.

    Returns:
        ``reward_func(prompts, completions, **kwargs) -> list[float]``.
    """

    def reward_func(prompts, completions, **kwargs) -> list[float]:
        """Return one reward per ``(prompt, completion)`` pair, in order."""
        del kwargs
        rewards: list[float] = []
        # GRPO scores several completions per prompt, so the same prompt
        # string can occur more than once in a batch. resolve_prompt() pops
        # its registry entry, so a repeated lookup would raise KeyError;
        # remember what this batch has already resolved instead.
        resolved: dict[str, dict[str, Any]] = {}

        for prompt, completion in zip(prompts, completions):
            problem = resolved.get(prompt)
            if problem is None:
                problem = controller.resolve_prompt(prompt)
                resolved[prompt] = problem

            env = AdaptEnvironment(generator=controller.generator, generator_mode=controller.mode)
            env.reset(
                difficulty=problem["difficulty_label"],
                generated_problem=problem,
                generator_mode=controller.mode,
                session_id=env.session_id,
            )
            observation = env.step(
                AdaptAction(
                    session_id=env.session_id,
                    code=extract_code(completion),
                )
            )
            rewards.append(float(observation.reward))
            controller.update(
                problem=problem,
                pass_rate=observation.pass_rate,
                generator_reward_signal=observation.generator_reward_signal,
            )
            curriculum.update(observation.pass_rate)
            logger.log_event(
                phase="train",
                episode_reward=float(observation.reward),
                pass_rate=float(observation.pass_rate),
                visible_pass_rate=float(observation.visible_pass_rate),
                difficulty_tier=problem["difficulty_label"],
                problem_family=problem["problem_type"],
                curriculum_level=curriculum.current_level(),
                execution_status=observation.execution_status,
                attempt_number=int(observation.attempt_number),
                family_productivity=controller.productivity_snapshot(),
                extra={
                    "generator_reward": round(float(observation.generator_reward_signal), 4),
                    "problem_id": problem["problem_id"],
                },
            )
            # Periodic progress dump of the family scores every 50 episodes.
            if controller.mode == "reward_aware" and controller.history["episode_index"] % 50 == 0:
                print("[family_productivity]", json.dumps(controller.productivity_snapshot()))

        return rewards

    return reward_func
|
| 375 |
|
| 376 |
|
| 377 |
+
def render_prompt(tokenizer: Any, prompt: str) -> str:
    """Wrap ``prompt`` in a system/user conversation and render it to text.

    Args:
        tokenizer: Tokenizer object; its ``apply_chat_template`` method is
            used when present.
        prompt: User-turn content.

    Returns:
        The chat-templated text with a generation prompt appended, or the
        system prompt and ``prompt`` joined by a blank line when the
        tokenizer has no ``apply_chat_template`` attribute.
    """
    if not hasattr(tokenizer, "apply_chat_template"):
        return f"{SYSTEM_PROMPT}\n\n{prompt}"

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def generate_completion(
    model: Any,
    tokenizer: Any,
    prompt: str,
    *,
    max_new_tokens: int,
) -> str:
    """Greedily decode one completion for ``prompt``.

    The prompt is rendered with ``render_prompt``, tokenized, moved to the
    model's device and passed to ``model.generate`` with sampling disabled.

    Args:
        model: Model exposing ``generate`` and either a ``device`` attribute
            or ``parameters()``.
        tokenizer: Tokenizer used for encoding and decoding.
        prompt: Solver prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens stripped.
    """
    rendered = render_prompt(tokenizer, prompt)
    encode_kwargs: dict[str, Any] = {"return_tensors": "pt"}
    # render_prompt applies the chat template whenever the tokenizer has
    # one, and templated text already carries its special tokens. Encoding
    # it with the default settings would add them a second time (e.g. a
    # doubled BOS), so switch that off for the templated path only.
    if hasattr(tokenizer, "apply_chat_template"):
        encode_kwargs["add_special_tokens"] = False
    inputs = tokenizer(rendered, **encode_kwargs)

    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # The output starts with the prompt tokens; keep only what follows them.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def run_policy_evaluation(
    *,
    model: Any,
    tokenizer: Any,
    generator_mode: str,
    deterministic_generator: bool,
    episodes: int,
    logger: TrainingLogger,
    phase: str,
    max_new_tokens: int,
) -> dict[str, Any]:
    """Evaluate the policy on a fixed easy/medium/hard episode schedule.

    Episodes are split as evenly as possible across the three tiers, with
    any remainder going to easy first and then medium. Each episode runs up
    to ``MAX_STEPS_PER_EPISODE`` attempts and stops early once the
    environment reports ``done``. A fresh ``GeneratorController`` is used and
    its family productivity scores are never updated here.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer paired with the model.
        generator_mode: Mode for the controller and the environments.
        deterministic_generator: Passed to the controller as ``deterministic``.
        episodes: Total number of evaluation episodes.
        logger: Logger receiving one event per episode.
        phase: Phase label attached to each logged event.
        max_new_tokens: Generation budget per attempt.

    Returns:
        Mean final pass rate per tier (``0.0`` for a tier with no episodes)
        plus an ``"overall"`` mean over all episodes.
    """
    tier_levels = {"easy": 1, "medium": 2, "hard": 3}
    controller = GeneratorController(
        mode=generator_mode,
        deterministic=deterministic_generator,
    )

    per_tier, leftover = divmod(episodes, 3)
    schedule = (
        ["easy"] * (per_tier + (1 if leftover > 0 else 0))
        + ["medium"] * (per_tier + (1 if leftover > 1 else 0))
        + ["hard"] * per_tier
    )
    schedule = schedule[:episodes]

    tier_records: dict[str, list[float]] = {tier: [] for tier in tier_levels}

    for difficulty in schedule:
        problem = controller.sample_problem(difficulty)
        env = AdaptEnvironment(generator=controller.generator, generator_mode=generator_mode)
        observation = env.reset(
            difficulty=difficulty,
            generated_problem=problem,
            session_id=env.session_id,
            generator_mode=generator_mode,
        )

        attempts_left = MAX_STEPS_PER_EPISODE
        while attempts_left > 0:
            attempts_left -= 1
            solver_prompt = build_solver_prompt(observation.model_dump())
            completion = generate_completion(
                model=model,
                tokenizer=tokenizer,
                prompt=solver_prompt,
                max_new_tokens=max_new_tokens,
            )
            action = AdaptAction(
                session_id=env.session_id,
                code=extract_code(completion),
            )
            observation = env.step(action)
            if observation.done:
                break

        # Record the episode in history without changing productivity scores.
        controller.update(
            problem=problem,
            pass_rate=observation.pass_rate,
            generator_reward_signal=observation.generator_reward_signal,
            update_productivity=False,
        )
        tier_records[difficulty].append(float(observation.pass_rate))
        logger.log_event(
            phase=phase,
            episode_reward=float(observation.reward),
            pass_rate=float(observation.pass_rate),
            visible_pass_rate=float(observation.visible_pass_rate),
            difficulty_tier=difficulty,
            problem_family=problem["problem_type"],
            curriculum_level=tier_levels[difficulty],
            execution_status=observation.execution_status,
            attempt_number=int(observation.attempt_number),
            family_productivity=controller.productivity_snapshot(),
            extra={
                "generator_reward": round(float(observation.generator_reward_signal), 4),
                "problem_id": problem["problem_id"],
            },
        )

    summary: dict[str, Any] = {}
    for tier, scores in tier_records.items():
        summary[tier] = sum(scores) / len(scores) if scores else 0.0

    if episodes:
        summary["overall"] = sum(score for scores in tier_records.values() for score in scores) / episodes
    else:
        summary["overall"] = 0.0
    return summary
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def print_evaluation_summary(baseline: dict[str, Any], trained: dict[str, Any]) -> None:
    """Print a table comparing baseline and trained pass rates.

    One row is printed for each of easy, medium, hard and overall; a tier
    missing from either mapping is shown as ``0.000``.

    Args:
        baseline: Pass rates measured before training, keyed by tier.
        trained: Pass rates measured after training, keyed by tier.
    """
    row_template = "{:<12} {:>10.3f} {:>10.3f}"

    print("\nBaseline vs trained pass rate summary")
    print(f"{'Difficulty':<12} {'Baseline':>10} {'Trained':>10}")
    print("-" * 34)
    for tier in ("easy", "medium", "hard", "overall"):
        before = baseline.get(tier, 0.0)
        after = trained.get(tier, 0.0)
        print(row_template.format(tier, before, after))
|
| 504 |
|
| 505 |
|
| 506 |
def run_training(args: argparse.Namespace) -> None:
|
|
|
|
| 512 |
"Training dependencies are missing. Install `trl` and `unsloth` before running GRPO training."
|
| 513 |
) from exc
|
| 514 |
|
| 515 |
+
output_dir = Path(args.output_dir)
|
| 516 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 517 |
+
|
| 518 |
PatchFastRL("GRPO", FastLanguageModel)
|
| 519 |
|
| 520 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
|
|
|
| 522 |
max_seq_length=args.max_seq_length,
|
| 523 |
load_in_4bit=not args.disable_4bit,
|
| 524 |
)
|
| 525 |
+
if tokenizer.pad_token is None:
|
| 526 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 527 |
|
| 528 |
model = FastLanguageModel.get_peft_model(
|
| 529 |
model,
|
|
|
|
| 538 |
mode="reward_aware" if args.generator_mode == "reward_aware" else "heuristic",
|
| 539 |
deterministic=not args.non_deterministic_generator,
|
| 540 |
)
|
| 541 |
+
logger = TrainingLogger(
|
| 542 |
+
output_dir=output_dir,
|
| 543 |
+
family_names=controller.family_names,
|
| 544 |
+
use_wandb=not args.disable_wandb,
|
| 545 |
+
wandb_project=args.wandb_project,
|
| 546 |
+
wandb_run_name=args.wandb_run_name,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
baseline_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
|
| 550 |
+
trained_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
|
| 551 |
+
|
| 552 |
+
if args.baseline_eval:
|
| 553 |
+
FastLanguageModel.for_inference(model)
|
| 554 |
+
baseline_summary = run_policy_evaluation(
|
| 555 |
+
model=model,
|
| 556 |
+
tokenizer=tokenizer,
|
| 557 |
+
generator_mode=controller.mode,
|
| 558 |
+
deterministic_generator=not args.non_deterministic_generator,
|
| 559 |
+
episodes=args.evaluation_episodes,
|
| 560 |
+
logger=logger,
|
| 561 |
+
phase="baseline_eval",
|
| 562 |
+
max_new_tokens=args.eval_max_new_tokens,
|
| 563 |
+
)
|
| 564 |
+
print(f"[baseline_eval] {json.dumps(baseline_summary)}")
|
| 565 |
+
|
| 566 |
training_args = GRPOConfig(
|
| 567 |
output_dir=args.output_dir,
|
| 568 |
learning_rate=args.learning_rate,
|
|
|
|
| 574 |
max_steps=args.max_steps,
|
| 575 |
logging_steps=1,
|
| 576 |
bf16=args.bf16,
|
| 577 |
+
report_to=[],
|
| 578 |
)
|
| 579 |
|
| 580 |
trainer = GRPOTrainer(
|
| 581 |
model=model,
|
| 582 |
+
reward_funcs=[build_reward_func(curriculum, controller, logger)],
|
| 583 |
args=training_args,
|
| 584 |
train_dataset=build_dataset(args.dataset_size, controller, curriculum),
|
| 585 |
)
|
| 586 |
trainer.train()
|
| 587 |
+
|
| 588 |
model.save_pretrained(args.output_dir)
|
| 589 |
tokenizer.save_pretrained(args.output_dir)
|
| 590 |
|
| 591 |
+
if args.baseline_eval:
|
| 592 |
+
FastLanguageModel.for_inference(model)
|
| 593 |
+
trained_summary = run_policy_evaluation(
|
| 594 |
+
model=model,
|
| 595 |
+
tokenizer=tokenizer,
|
| 596 |
+
generator_mode=controller.mode,
|
| 597 |
+
deterministic_generator=not args.non_deterministic_generator,
|
| 598 |
+
episodes=args.evaluation_episodes,
|
| 599 |
+
logger=logger,
|
| 600 |
+
phase="trained_eval",
|
| 601 |
+
max_new_tokens=args.eval_max_new_tokens,
|
| 602 |
+
)
|
| 603 |
+
print(f"[trained_eval] {json.dumps(trained_summary)}")
|
| 604 |
+
print_evaluation_summary(baseline_summary, trained_summary)
|
| 605 |
+
|
| 606 |
+
csv_path = logger.write_csv()
|
| 607 |
+
logger.close()
|
| 608 |
+
print(f"[artifacts] reward curve CSV written to {csv_path}")
|
| 609 |
+
|
| 610 |
|
| 611 |
def build_parser() -> argparse.ArgumentParser:
|
| 612 |
parser = argparse.ArgumentParser(description="GRPO training entrypoint for the ADAPT DSA environment.")
|
| 613 |
parser.add_argument("--model-name", default="unsloth/Llama-3.2-3B-Instruct")
|
| 614 |
+
parser.add_argument("--output-dir", default="outputs_v3")
|
| 615 |
parser.add_argument("--dataset-size", type=int, default=200)
|
| 616 |
parser.add_argument("--max-steps", type=int, default=250)
|
| 617 |
parser.add_argument("--batch-size", type=int, default=1)
|
| 618 |
parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
|
| 619 |
parser.add_argument("--num-generations", type=int, default=8)
|
| 620 |
parser.add_argument("--max-seq-length", type=int, default=2048)
|
| 621 |
+
parser.add_argument("--max-prompt-length", type=int, default=1024)
|
| 622 |
parser.add_argument("--max-completion-length", type=int, default=512)
|
| 623 |
parser.add_argument("--learning-rate", type=float, default=5e-6)
|
| 624 |
parser.add_argument("--lora-rank", type=int, default=16)
|
| 625 |
parser.add_argument("--lora-alpha", type=int, default=16)
|
| 626 |
parser.add_argument("--disable-4bit", action="store_true")
|
| 627 |
parser.add_argument("--bf16", action="store_true")
|
| 628 |
+
parser.add_argument("--baseline-eval", action="store_true")
|
| 629 |
+
parser.add_argument("--evaluation-episodes", type=int, default=20)
|
| 630 |
+
parser.add_argument("--eval-max-new-tokens", type=int, default=512)
|
| 631 |
+
parser.add_argument("--disable-wandb", action="store_true")
|
| 632 |
+
parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
|
| 633 |
+
parser.add_argument("--wandb-run-name", default=None)
|
| 634 |
parser.add_argument(
|
| 635 |
"--generator-mode",
|
| 636 |
choices=["heuristic", "reward_aware"],
|
| 637 |
+
default="reward_aware",
|
| 638 |
+
help="Use heuristic generation or reward-aware family weighting.",
|
| 639 |
)
|
| 640 |
parser.add_argument(
|
| 641 |
"--non-deterministic-generator",
|
verifier/metrics.py
CHANGED
|
@@ -3,30 +3,54 @@ from __future__ import annotations
|
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
|
| 6 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
total = len(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
passed = sum(1 for result in results if result["passed"])
|
|
|
|
| 9 |
timeout_count = sum(1 for result in results if result["status"] == "timeout")
|
| 10 |
runtime_error_count = sum(1 for result in results if result["status"] == "runtime_error")
|
| 11 |
invalid_output_count = sum(1 for result in results if result["status"] == "invalid_output_format")
|
| 12 |
wrong_answer_count = sum(1 for result in results if result["status"] == "wrong_answer")
|
| 13 |
format_ok_count = sum(1 for result in results if result.get("format_ok", False))
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
format_compliance = format_ok_count / total if total else 0.0
|
| 17 |
-
timeout_rate = timeout_count / total if total else 0.0
|
| 18 |
-
runtime_error_rate = runtime_error_count / total if total else 0.0
|
| 19 |
-
invalid_output_rate = invalid_output_count / total if total else 0.0
|
| 20 |
-
|
| 21 |
-
reward_components = {
|
| 22 |
-
"correctness": 0.8 * pass_rate,
|
| 23 |
-
"format": 0.1 * format_compliance,
|
| 24 |
-
"execution": 0.1 if timeout_count == 0 and runtime_error_count == 0 else 0.0,
|
| 25 |
-
"timeout_penalty": -0.2 * timeout_rate,
|
| 26 |
-
"runtime_penalty": -0.1 * runtime_error_rate,
|
| 27 |
-
"invalid_output_penalty": -0.1 * invalid_output_rate,
|
| 28 |
-
}
|
| 29 |
-
reward = max(0.0, min(1.0, sum(reward_components.values())))
|
| 30 |
|
| 31 |
if timeout_count:
|
| 32 |
execution_status = "timeout"
|
|
@@ -39,10 +63,23 @@ def compute_pass_rate(results: list[dict[str, Any]]) -> tuple[float, dict[str, A
|
|
| 39 |
else:
|
| 40 |
execution_status = "completed"
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return reward, {
|
| 43 |
"passed": passed,
|
| 44 |
"total": total,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
"pass_rate": round(pass_rate, 4),
|
|
|
|
|
|
|
| 46 |
"timeout_count": timeout_count,
|
| 47 |
"runtime_error_count": runtime_error_count,
|
| 48 |
"invalid_output_count": invalid_output_count,
|
|
@@ -50,6 +87,8 @@ def compute_pass_rate(results: list[dict[str, Any]]) -> tuple[float, dict[str, A
|
|
| 50 |
"format_compliance": round(format_compliance, 4),
|
| 51 |
"execution_status": execution_status,
|
| 52 |
"reward_components": {
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
},
|
| 55 |
}
|
|
|
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
|
| 6 |
+
def compute_reward(
    pass_rate: float,
    step_number: int,
    execution_status: str,
    format_compliance: float,
) -> float:
    """Compute the scalar reward for one attempt.

    The reward is the pass rate scaled by a per-attempt discount (``1.0`` on
    step 1, ``0.85`` on step 2, ``0.70`` otherwise), clamped to ``[0, 1]``
    and rounded to four decimals. Timeouts and syntax errors always earn
    ``0.0``.

    Args:
        pass_rate: Fraction of tests passed.
        step_number: Attempt number within the episode.
        execution_status: Status string of the attempt.
        format_compliance: Accepted for interface compatibility; ignored.

    Returns:
        The reward in ``[0.0, 1.0]``.
    """
    del format_compliance

    if execution_status in ("timeout", "syntax_error"):
        return 0.0

    if step_number == 1:
        discount = 1.0
    elif step_number == 2:
        discount = 0.85
    else:
        discount = 0.70

    raw_reward = pass_rate * discount
    clamped = min(max(raw_reward, 0.0), 1.0)
    return round(clamped, 4)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def compute_pass_rate(
|
| 30 |
+
results: list[dict[str, Any]],
|
| 31 |
+
step_number: int = 1,
|
| 32 |
+
) -> tuple[float, dict[str, Any]]:
|
| 33 |
total = len(results)
|
| 34 |
+
hidden_results = [result for result in results if result.get("visibility") == "hidden"]
|
| 35 |
+
visible_results = [result for result in results if result.get("visibility") == "visible"]
|
| 36 |
+
|
| 37 |
+
hidden_total = len(hidden_results)
|
| 38 |
+
visible_total = len(visible_results)
|
| 39 |
+
|
| 40 |
+
hidden_passed = sum(1 for result in hidden_results if result["passed"])
|
| 41 |
+
visible_passed = sum(1 for result in visible_results if result["passed"])
|
| 42 |
passed = sum(1 for result in results if result["passed"])
|
| 43 |
+
|
| 44 |
timeout_count = sum(1 for result in results if result["status"] == "timeout")
|
| 45 |
runtime_error_count = sum(1 for result in results if result["status"] == "runtime_error")
|
| 46 |
invalid_output_count = sum(1 for result in results if result["status"] == "invalid_output_format")
|
| 47 |
wrong_answer_count = sum(1 for result in results if result["status"] == "wrong_answer")
|
| 48 |
format_ok_count = sum(1 for result in results if result.get("format_ok", False))
|
| 49 |
|
| 50 |
+
hidden_pass_rate = hidden_passed / hidden_total if hidden_total else 0.0
|
| 51 |
+
visible_pass_rate = visible_passed / visible_total if visible_total else 0.0
|
| 52 |
+
pass_rate = hidden_pass_rate if hidden_total else (passed / total if total else 0.0)
|
| 53 |
format_compliance = format_ok_count / total if total else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
if timeout_count:
|
| 56 |
execution_status = "timeout"
|
|
|
|
| 63 |
else:
|
| 64 |
execution_status = "completed"
|
| 65 |
|
| 66 |
+
reward = compute_reward(
|
| 67 |
+
pass_rate=pass_rate,
|
| 68 |
+
step_number=step_number,
|
| 69 |
+
execution_status=execution_status,
|
| 70 |
+
format_compliance=format_compliance,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
return reward, {
|
| 74 |
"passed": passed,
|
| 75 |
"total": total,
|
| 76 |
+
"hidden_passed": hidden_passed,
|
| 77 |
+
"hidden_total": hidden_total,
|
| 78 |
+
"visible_passed": visible_passed,
|
| 79 |
+
"visible_total": visible_total,
|
| 80 |
"pass_rate": round(pass_rate, 4),
|
| 81 |
+
"hidden_pass_rate": round(hidden_pass_rate, 4),
|
| 82 |
+
"visible_pass_rate": round(visible_pass_rate, 4),
|
| 83 |
"timeout_count": timeout_count,
|
| 84 |
"runtime_error_count": runtime_error_count,
|
| 85 |
"invalid_output_count": invalid_output_count,
|
|
|
|
| 87 |
"format_compliance": round(format_compliance, 4),
|
| 88 |
"execution_status": execution_status,
|
| 89 |
"reward_components": {
|
| 90 |
+
"correctness": round(float(pass_rate), 4),
|
| 91 |
+
"step_discount": 1.0 if step_number == 1 else (0.85 if step_number == 2 else 0.70),
|
| 92 |
+
"reward": reward,
|
| 93 |
},
|
| 94 |
}
|