eyad-silx commited on
Commit
3604754
·
verified ·
1 Parent(s): febab99

Upload backprop_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. backprop_test.py +100 -0
backprop_test.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test Backprop NEAT on 2D classification tasks."""
2
+
3
+ import os
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import networkx as nx
9
+ from neat.datasets import (generate_xor_data, generate_circle_data,
10
+ generate_spiral_data, plot_dataset)
11
+ from neat.backprop_neat import BackpropNEAT
12
+
13
+ def train_and_visualize(neat: BackpropNEAT, x: jnp.ndarray, y: jnp.ndarray,
14
+ dataset_name: str, viz_dir: str = 'visualizations'):
15
+ """Train network and save visualizations."""
16
+ os.makedirs(viz_dir, exist_ok=True)
17
+
18
+ # Plot dataset
19
+ plot_dataset(x, y, f'{dataset_name} Dataset')
20
+ plt.savefig(os.path.join(viz_dir, f'{dataset_name}_dataset.png'))
21
+ plt.close()
22
+
23
+ # Training loop
24
+ n_generations = 50
25
+ n_epochs = 100
26
+
27
+ for gen in range(n_generations):
28
+ # Train networks with backprop
29
+ neat.train_networks(x, y, n_epochs=n_epochs)
30
+
31
+ # Evaluate fitness
32
+ neat.evaluate_fitness(x, y)
33
+
34
+ # Get best network
35
+ best_network = max(neat.population, key=lambda n: n.fitness)
36
+
37
+ # Save visualizations every 10 generations
38
+ if gen % 10 == 0:
39
+ gen_dir = os.path.join(viz_dir, f'gen_{gen:03d}')
40
+ os.makedirs(gen_dir, exist_ok=True)
41
+
42
+ # Visualize network architecture
43
+ best_network.visualize(
44
+ save_path=os.path.join(gen_dir, f'{dataset_name}_network.png'))
45
+
46
+ # Plot decision boundary
47
+ plt.figure(figsize=(8, 8))
48
+
49
+ # Create grid of points
50
+ xx, yy = jnp.meshgrid(jnp.linspace(-1, 1, 100),
51
+ jnp.linspace(-1, 1, 100))
52
+ grid_points = jnp.stack([xx.ravel(), yy.ravel()], axis=1)
53
+
54
+ # Get predictions
55
+ predictions = jnp.array([best_network.forward(p)[0]
56
+ for p in grid_points])
57
+ predictions = predictions.reshape(xx.shape)
58
+
59
+ # Plot decision boundary
60
+ plt.contourf(xx, yy, predictions, alpha=0.4,
61
+ levels=jnp.linspace(0, 1, 20))
62
+ plot_dataset(x, y, f'{dataset_name} - Generation {gen}')
63
+ plt.savefig(os.path.join(gen_dir,
64
+ f'{dataset_name}_decision_boundary.png'))
65
+ plt.close()
66
+
67
+ # Evolve population
68
+ neat.evolve_population()
69
+
70
+ print(f'Generation {gen}: Best Fitness = {best_network.fitness:.4f}')
71
+
72
+ def main():
73
+ """Run experiments on different datasets."""
74
+ # Parameters
75
+ n_points = 50 # Points per quadrant/class
76
+ noise_level = 0.1
77
+ population_size = 50
78
+
79
+ # Test on different datasets
80
+ datasets = [
81
+ ('XOR', generate_xor_data),
82
+ ('Circle', generate_circle_data),
83
+ ('Spiral', generate_spiral_data)
84
+ ]
85
+
86
+ for name, generator in datasets:
87
+ print(f'\nTraining on {name} dataset:')
88
+
89
+ # Generate dataset
90
+ x, y = generator(n_points, noise_level)
91
+
92
+ # Create and train NEAT
93
+ neat = BackpropNEAT(n_inputs=2, n_outputs=1,
94
+ population_size=population_size)
95
+
96
+ # Train and visualize
97
+ train_and_visualize(neat, x, y, name)
98
+
99
+ if __name__ == '__main__':
100
+ main()