eyad-silx committed
Commit 2131369 · verified · 1 parent: 8e3abc5

Upload neat\network.py with huggingface_hub

Files changed (1)
neat/network.py +452 -0
neat/network.py ADDED
@@ -0,0 +1,452 @@
"""Neural network implementation for BackpropNEAT."""

import jax
import jax.numpy as jnp
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
from .genome import Genome
import copy
import random


class Network:
    """Neural network for NEAT implementation.

    Implements a strictly feed-forward network following original NEAT principles:
    1. Start minimal - direct input-output connections only
    2. Complexify gradually through structural mutations
    3. Protect innovation through speciation
    4. No recurrent connections (as per requirements)
    """
    def __init__(self, genome: Genome):
        """Initialize network from genome."""
        # Store genome and sizes
        self.genome = genome

        # Verify genome sizes match volleyball requirements
        if genome.input_size != 12 or genome.output_size != 3:
            print(f"Warning: Genome size mismatch. Expected 12 inputs, 3 outputs. Got {genome.input_size} inputs, {genome.output_size} outputs")
            genome.input_size = 12
            genome.output_size = 3

        self.input_size = 12   # Fixed for volleyball
        self.output_size = 3   # Fixed for volleyball

        # Fresh containers to avoid shared references with the genome
        self.node_genes = {}
        self.connection_genes = []

        # Create input nodes (0 to 11)
        for i in range(12):
            self.node_genes[i] = NodeGene(i, 'input', 'linear')

        # Create bias node (12)
        self.node_genes[12] = NodeGene(12, 'bias', 'linear')

        # Create output nodes (13, 14, 15)
        for i in range(3):
            node_id = 13 + i
            self.node_genes[node_id] = NodeGene(node_id, 'output', 'sigmoid')

            # Connect to bias with appropriate weight based on action type
            if i < 2:  # Left/Right actions: encourage movement
                self.connection_genes.append(
                    ConnectionGene(12, node_id, random.uniform(0.0, 1.0), True)
                )
            else:  # Jump action: neutral bias
                self.connection_genes.append(
                    ConnectionGene(12, node_id, random.uniform(-0.5, 0.5), True)
                )

            # Connect to relevant inputs with larger weights
            if i == 0:  # Left action: connect to ball x position and velocity
                self.connection_genes.append(
                    ConnectionGene(0, node_id, random.uniform(0.5, 1.5), True)  # ball x
                )
                self.connection_genes.append(
                    ConnectionGene(2, node_id, random.uniform(0.5, 1.5), True)  # ball vx
                )
            elif i == 1:  # Right action: connect to ball x position and velocity
                self.connection_genes.append(
                    ConnectionGene(0, node_id, random.uniform(-1.5, -0.5), True)  # ball x
                )
                self.connection_genes.append(
                    ConnectionGene(2, node_id, random.uniform(-1.5, -0.5), True)  # ball vx
                )
            else:  # Jump action: connect to ball y position and velocity
                self.connection_genes.append(
                    ConnectionGene(1, node_id, random.uniform(-1.5, -0.5), True)  # ball y
                )
                self.connection_genes.append(
                    ConnectionGene(3, node_id, random.uniform(-1.0, 0.0), True)  # ball vy
                )

        # Copy existing hidden nodes from the genome (if any)
        for node_id, node in genome.node_genes.items():
            if node_id not in self.node_genes:  # Skip already-created input/bias/output nodes
                self.node_genes[node_id] = NodeGene(
                    node_id,
                    node.node_type,
                    node.activation
                )

        # Copy connections
        if genome.connection_genes:
            # Clear initial connections if genome has its own
            self.connection_genes = []
            for conn in genome.connection_genes:
                # Verify connection nodes exist
                if conn.source not in self.node_genes or conn.target not in self.node_genes:
                    print(f"Warning: Connection {conn.source}->{conn.target} references missing nodes")
                    continue
                self.connection_genes.append(ConnectionGene(
                    conn.source,
                    conn.target,
                    conn.weight,
                    conn.enabled
                ))

        # Verify output connections (13, 14, 15)
        for output_id in [13, 14, 15]:
            has_connection = False
            for conn in self.connection_genes:
                if conn.enabled and conn.target == output_id:
                    has_connection = True
                    break

            if not has_connection:
                print(f"Adding missing connections for output {output_id}")
                # Connect to bias
                self.connection_genes.append(
                    ConnectionGene(12, output_id, random.uniform(-1.0, 1.0), True)
                )
                # Connect to random input
                input_id = random.randint(0, 11)
                self.connection_genes.append(
                    ConnectionGene(input_id, output_id, random.uniform(-1.0, 1.0), True)
                )

        # Build evaluation order
        self.node_evals = {}
        self._build_feed_forward_order()

        # Verify all outputs are properly connected
        self._verify_outputs()

    def _verify_outputs(self):
        """Verify all outputs have valid connections and evaluations."""
        output_ids = {13, 14, 15}  # Fixed output IDs

        # Check node evaluations
        for output_id in output_ids:
            if output_id not in self.node_evals:
                print(f"Adding missing evaluation for output {output_id}")
                bias_id = 12
                self.node_evals[output_id] = {
                    'inputs': [bias_id],
                    'weights': [1.0],
                    'activation': 'sigmoid'
                }
                # Add connection if needed
                if not any(c.target == output_id and c.enabled for c in self.connection_genes):
                    self.connection_genes.append(
                        ConnectionGene(bias_id, output_id, 1.0, True)
                    )

    def _create_minimal_connections(self):
        """Create minimal initial connections for a new network."""
        bias_id = 12
        output_start = bias_id + 1

        # Connect each output to bias and one random input
        for i in range(self.output_size):
            output_id = output_start + i

            # Connect to bias
            self.connection_genes.append(ConnectionGene(
                bias_id, output_id,
                random.uniform(-1.0, 1.0),
                True
            ))

            # Connect to random input
            input_id = random.randint(0, self.input_size - 1)
            self.connection_genes.append(ConnectionGene(
                input_id, output_id,
                random.uniform(-1.0, 1.0),
                True
            ))

    def _build_feed_forward_order(self):
        """Build evaluation order ensuring feed-forward only topology."""
        try:
            # Fixed node sets for volleyball
            input_nodes = set(range(12))   # 0-11
            bias_node = {12}               # Bias node
            output_nodes = {13, 14, 15}    # Output nodes

            # Map each node to the sources feeding into it (incoming edges)
            incoming_map = {}
            for conn in self.connection_genes:
                if not conn.enabled:
                    continue
                incoming_map.setdefault(conn.target, []).append(conn.source)

            # Start with inputs and bias evaluated
            evaluated = input_nodes | bias_node
            eval_order = []

            # A node is ready once every node feeding into it has been evaluated
            def can_evaluate(node_id):
                return all(src in evaluated for src in incoming_map.get(node_id, []))

            # Keep trying to evaluate nodes until we can't anymore
            while True:
                ready_nodes = set()
                for node_id in self.node_genes:
                    if node_id not in evaluated and can_evaluate(node_id):
                        ready_nodes.add(node_id)

                if not ready_nodes:
                    break

                # Add nodes to evaluation order
                for node_id in sorted(ready_nodes):
                    incoming = []
                    incoming_weights = []
                    for conn in self.connection_genes:
                        if conn.enabled and conn.target == node_id:
                            incoming.append(conn.source)
                            incoming_weights.append(conn.weight)

                    if incoming:  # Only add if node has inputs
                        self.node_evals[node_id] = {
                            'inputs': incoming,
                            'weights': incoming_weights,
                            'activation': self.node_genes[node_id].activation
                        }
                        eval_order.append(node_id)

                    evaluated.add(node_id)

            # Ensure all outputs have evaluations
            for output_id in output_nodes:
                if output_id not in self.node_evals:
                    print(f"Adding default evaluation for output {output_id}")
                    # Connect to bias by default
                    self.node_evals[output_id] = {
                        'inputs': [12],  # Bias node
                        'weights': [1.0],
                        'activation': 'sigmoid'
                    }
                    # Add connection if needed
                    if not any(c.target == output_id and c.enabled for c in self.connection_genes):
                        self.connection_genes.append(
                            ConnectionGene(12, output_id, 1.0, True)
                        )

        except Exception as e:
            print(f"Error in feed-forward build: {e}")
            # Create minimal fallback evaluations
            self.node_evals = {}
            for i in range(3):  # 3 outputs
                output_id = 13 + i
                self.node_evals[output_id] = {
                    'inputs': [12],  # Bias node
                    'weights': [1.0],
                    'activation': 'sigmoid'
                }
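
    # A rough worked example of the ordering pass above (illustrative only): with
    # just the seeded connections 0->13, 2->13 and 12->13 enabled, inputs 0-11 and
    # bias 12 start out as "evaluated", so output 13 is ready on the first sweep
    # and receives roughly
    #     node_evals[13] = {'inputs': [0, 2, 12], 'weights': [...], 'activation': 'sigmoid'}
    # If a hidden node were later spliced between input 0 and output 13, it would be
    # scheduled on an earlier sweep than 13, since a node only becomes ready once
    # every source feeding into it has been evaluated.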

    def forward(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the network."""
        try:
            # Only use the first 8 observation values, like the original network
            inputs = inputs[..., :8]

            # Handle input shape
            original_shape = inputs.shape
            if len(inputs.shape) == 1:
                inputs = inputs.reshape(1, -1)
            batch_size = inputs.shape[0]

            # Get max node ID for activation array
            max_node_id = max(node.id for node in self.node_genes.values())

            # Initialize activations array
            activations = jnp.zeros((batch_size, max_node_id + 1))

            # Set input values (0-7)
            for i in range(8):
                if i < inputs.shape[1]:
                    activations = activations.at[:, i].set(inputs[:, i])
                else:
                    activations = activations.at[:, i].set(0.0)

            # Input slots 8-11 receive no observation here; keep them at zero
            for i in range(8, 12):
                activations = activations.at[:, i].set(0.0)

            # Bias node (12) outputs a constant 1.0 so bias connections take effect
            activations = activations.at[:, 12].set(1.0)

            # Evaluate nodes in order (hidden then output)
            for node_id, eval_info in self.node_evals.items():
                try:
                    # Skip input nodes (0-11)
                    if node_id < 12:
                        continue

                    # Get weighted sum of inputs
                    act = jnp.zeros(batch_size)
                    for conn_source, conn_weight in zip(eval_info['inputs'], eval_info['weights']):
                        act += activations[:, conn_source] * conn_weight

                    # Apply activation function
                    if eval_info['activation'] == 'tanh':
                        act = jnp.tanh(act)
                    elif eval_info['activation'] == 'sigmoid':
                        act = jax.nn.sigmoid(act)
                    elif eval_info['activation'] == 'relu':
                        act = jax.nn.relu(act)

                    # Apply threshold like the original network for output nodes (13-15)
                    if node_id in (13, 14, 15):
                        act = jnp.where(act > 0.75, 1.0, 0.0)

                    activations = activations.at[:, node_id].set(act)
                except Exception as e:
                    print(f"Error at node {node_id}: {e}")

            # Get output node activations (nodes 13-15)
            output = activations[:, 13:16]

            # No recurrent state is kept: the network is strictly feed-forward

            # Return to original shape
            if len(original_shape) == 1:
                output = output.reshape(-1)

            return output
        except Exception as e:
            print(f"Error in forward pass: {e}")
            return jnp.zeros(3)
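
    # Shape sketch for forward() (illustrative; assumes a network built from a
    # minimal 12-in / 3-out genome):
    #     net.forward(jnp.zeros(12)).shape       # -> (3,)
    #     net.forward(jnp.zeros((5, 12))).shape  # -> (5, 3)
    # Only the first 8 observation values are used; the three outputs are
    # thresholded to 0.0 or 1.0.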

    def predict(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Make a prediction for the given inputs.

        Args:
            inputs: Input array of shape (input_size,) or (batch_size, input_size)

        Returns:
            Predictions of shape (3,) for single input or (batch_size, 3) for batch
        """
        outputs = self.forward(inputs)

        # Ensure correct output shape for volleyball (always 3 outputs)
        if len(outputs.shape) == 1:
            # Single input case - ensure shape (3,)
            if outputs.shape[0] != 3:
                print(f"Adjusting output shape from {outputs.shape} to (3,)")
                return jnp.pad(outputs, (0, max(0, 3 - outputs.shape[0])))
            return outputs
        else:
            # Batch case - ensure shape (batch_size, 3)
            if outputs.shape[1] != 3:
                print(f"Adjusting output shape from {outputs.shape} to (batch_size, 3)")
                return jnp.pad(outputs, ((0, 0), (0, max(0, 3 - outputs.shape[1]))))
            return outputs

    def clone(self) -> 'Network':
        """Create a copy of this network with a cloned genome."""
        return Network(self.genome.clone())

    def mutate(self, config: Dict):
        """Mutate the network's genome."""
        self.genome.mutate(config)
        # Rebuild evaluation order after mutation
        self._build_feed_forward_order()

    def to_genome(self) -> Genome:
        """Convert network back to genome representation."""
        genome = Genome(self.input_size, self.output_size)
        genome.node_genes = copy.deepcopy(self.node_genes)
        genome.connection_genes = copy.deepcopy(self.connection_genes)
        return genome


class BaseNetwork:
    """Base Network class for NEAT."""

    def __init__(self, n_inputs: int, n_outputs: int):
        self.input_size = n_inputs
        self.output_size = n_outputs
        self.fitness = float('-inf')

        # Initialize weights and biases with JAX
        key = jax.random.PRNGKey(0)
        # Use larger initial weights to encourage exploration
        self.weights = jax.random.normal(key, (n_outputs, n_inputs)) * 0.5
        # Add small positive bias to encourage some initial movement
        self.bias = jnp.ones(n_outputs) * 0.1

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the network."""
        if x.ndim > 1:
            # Batched input
            h = jnp.dot(x, self.weights.T) + self.bias[None, :]
        else:
            # Single input
            h = jnp.dot(x, self.weights.T) + self.bias
        return jnp.tanh(h)

    def get_params(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Get network parameters."""
        return self.weights, self.bias

    def set_params(self, params: Tuple[jnp.ndarray, jnp.ndarray]):
        """Set network parameters."""
        self.weights, self.bias = params

    def get_weights_numpy(self) -> np.ndarray:
        """Get weights as numpy array for visualization."""
        return np.array(self.weights)
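
# Usage sketch for BaseNetwork (illustrative): it is a plain fully connected
# linear layer with tanh outputs, independent of the genome-based Network above.
#     base = BaseNetwork(n_inputs=12, n_outputs=3)
#     base.forward(jnp.zeros(12)).shape        # -> (3,)
#     base.forward(jnp.zeros((4, 12))).shape   # -> (4, 3)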


class NodeGene:
    """Node gene containing node information."""

    def __init__(self, node_id: int, node_type: str, activation: str = 'tanh'):
        """Initialize node gene.

        Args:
            node_id: Node ID
            node_type: Type of node ('input', 'bias', 'hidden', or 'output')
            activation: Activation function ('linear', 'tanh', 'sigmoid', or 'relu')
        """
        self.id = node_id
        self.type = node_type
        self.node_type = node_type  # Alias so genomes built via to_genome() can be re-read
        self.activation = activation
        # Initialize with larger random bias for hidden/output nodes
        if node_type in ['hidden', 'output']:
            key = jax.random.PRNGKey(node_id)  # Use node_id as seed for reproducibility
            self.bias = jax.random.normal(key, ()) * 0.5  # Increased from 0.1
        else:
            self.bias = 0.0  # No bias for input nodes


class ConnectionGene:
    """Gene representing a connection between nodes."""

    def __init__(self, source: int, target: int, weight: Optional[float] = None, enabled: bool = True):
        self.source = source
        self.target = target
        # Initialize with larger weights if not provided
        if weight is None:
            key = jax.random.PRNGKey(hash((source, target)) % 2**32)
            self.weight = jax.random.uniform(key, (), minval=-2.0, maxval=2.0)
        else:
            self.weight = weight
        self.enabled = enabled
        self.innovation = None  # Will be set by NEAT
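
A minimal usage sketch for the classes above, assuming the package is importable as neat and that Genome(12, 3) builds a minimal genome the way to_genome() implies (the actual neat/genome.py API is not part of this commit):

    import jax.numpy as jnp
    from neat.genome import Genome
    from neat.network import Network

    genome = Genome(12, 3)        # 12 observations in, 3 actions out
    net = Network(genome)         # seeds bias and input connections for each action

    obs = jnp.zeros(12)           # one volleyball observation
    action = net.predict(obs)     # shape (3,): left, right, jump (each 0.0 or 1.0)

    child = net.clone()           # copy backed by a cloned genome
    # child.mutate(config)        # structural/weight mutation, given a NEAT config dict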