eyad-silx commited on
Commit
8e3abc5
·
verified ·
1 Parent(s): 41ff170

Upload neat\genome.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. neat//genome.py +454 -0
neat//genome.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """NEAT Genome implementation.
2
+
3
+ This module implements the core NEAT genome structure and operations.
4
+ Each genome represents a neural network with nodes (neurons) and connections (synapses).
5
+ The genome can be mutated to evolve the network structure and weights over time.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ import jax.numpy as jnp
10
+ import jax.random as jrandom
11
+ from typing import Dict, List, Tuple, Optional
12
+ import time
13
+ import random
14
+ import numpy as np
15
+
16
+ @dataclass
17
+ class NodeGene:
18
+ """Node gene containing activation function and type.
19
+
20
+ Attributes:
21
+ node_id: Unique identifier for this node
22
+ node_type: Type of node ('input', 'hidden', 'recurrent', or 'output')
23
+ activation: Activation function ('tanh', 'relu', 'sigmoid', or 'linear')
24
+ """
25
+ node_id: int
26
+ node_type: str # 'input', 'hidden', 'recurrent', or 'output'
27
+ activation: str # 'tanh', 'relu', 'sigmoid', or 'linear'
28
+
29
+ @dataclass
30
+ class ConnectionGene:
31
+ """Connection gene containing connection properties.
32
+
33
+ Attributes:
34
+ source: ID of source node
35
+ target: ID of target node
36
+ weight: Connection weight
37
+ enabled: Whether connection is enabled
38
+ innovation: Unique innovation number for this connection
39
+ """
40
+ source: int
41
+ target: int
42
+ weight: float
43
+ enabled: bool = True
44
+ innovation: int = 0
45
+
46
+ class Genome:
47
+ """NEAT Genome implementation.
48
+
49
+ A genome represents a neural network as a collection of node and connection genes.
50
+ The network topology can be modified through mutation operations.
51
+
52
+ Attributes:
53
+ input_size: Number of input nodes
54
+ output_size: Number of output nodes
55
+ node_genes: Dictionary mapping node IDs to NodeGene objects
56
+ connection_genes: List of ConnectionGene objects
57
+ key: Random key for reproducible randomness
58
+ innovation_number: Counter for assigning unique innovation numbers
59
+ """
60
+
61
+ def __init__(self, input_size: int, output_size: int):
62
+ """Initialize genome with specified number of inputs and outputs.
63
+
64
+ Args:
65
+ input_size: Number of input nodes
66
+ output_size: Number of output nodes (must be 3 for volleyball)
67
+ """
68
+ self.input_size = input_size
69
+ self.output_size = output_size
70
+ self.node_genes: Dict[int, NodeGene] = {}
71
+ self.connection_genes: List[ConnectionGene] = []
72
+
73
+ # Initialize random key
74
+ timestamp = int(time.time() * 1000)
75
+ self.key = jrandom.PRNGKey(hash((input_size, output_size, timestamp)) % (2**32))
76
+
77
+ # Counter for assigning unique innovation numbers
78
+ self.innovation_number = 0
79
+
80
+ # Initialize minimal network structure
81
+ self._init_minimal()
82
+
83
+ def _init_minimal(self):
84
+ """Initialize minimal feed-forward network structure.
85
+
86
+ Network structure:
87
+ - Input nodes [0-7]: Game state inputs
88
+ - Hidden layer 1 [8-15]: First processing layer (8 nodes)
89
+ - Hidden layer 2 [16-23]: Second processing layer (8 nodes)
90
+ - Output nodes [24-26]: Action outputs (left, right, jump)
91
+
92
+ Using larger initial weights for faster learning:
93
+ - Input->Hidden1: N(0, 2.0) for strong initial responses
94
+ - Hidden1->Hidden2: N(0, 2.0) for feature processing
95
+ - Hidden2->Output: N(0, 4.0) for decisive actions
96
+ """
97
+ # Create input nodes (0-7)
98
+ for i in range(8): # Only 8 inputs used
99
+ self.node_genes[i] = NodeGene(
100
+ node_id=i,
101
+ node_type='input',
102
+ activation='linear' # Input nodes are always linear
103
+ )
104
+
105
+ # Create first hidden layer (8-15)
106
+ hidden1_size = 8
107
+ hidden1_start = 8 # Right after inputs
108
+ for i in range(hidden1_size):
109
+ node_id = hidden1_start + i
110
+ self.node_genes[node_id] = NodeGene(
111
+ node_id=node_id,
112
+ node_type='hidden',
113
+ activation='relu' # ReLU for faster learning
114
+ )
115
+
116
+ # Connect all inputs to this hidden node
117
+ for input_id in range(8):
118
+ weight = float(jrandom.normal(self.key) * 2.0)
119
+ self.connection_genes.append(ConnectionGene(
120
+ source=input_id,
121
+ target=node_id,
122
+ weight=weight,
123
+ enabled=True,
124
+ innovation=self.innovation_number
125
+ ))
126
+ self.innovation_number += 1
127
+
128
+ # Create second hidden layer (16-23)
129
+ hidden2_size = 8
130
+ hidden2_start = hidden1_start + hidden1_size
131
+ for i in range(hidden2_size):
132
+ node_id = hidden2_start + i
133
+ self.node_genes[node_id] = NodeGene(
134
+ node_id=node_id,
135
+ node_type='hidden',
136
+ activation='relu' # ReLU for faster learning
137
+ )
138
+
139
+ # Connect all hidden1 nodes to this hidden2 node
140
+ for h1_id in range(hidden1_start, hidden1_start + hidden1_size):
141
+ weight = float(jrandom.normal(self.key) * 2.0)
142
+ self.connection_genes.append(ConnectionGene(
143
+ source=h1_id,
144
+ target=node_id,
145
+ weight=weight,
146
+ enabled=True,
147
+ innovation=self.innovation_number
148
+ ))
149
+ self.innovation_number += 1
150
+
151
+ # Create output nodes (24-26)
152
+ output_start = hidden2_start + hidden2_size
153
+ for i in range(self.output_size):
154
+ node_id = output_start + i
155
+ self.node_genes[node_id] = NodeGene(
156
+ node_id=node_id,
157
+ node_type='output',
158
+ activation='tanh' # tanh for [-1,1] outputs
159
+ )
160
+
161
+ # Connect all hidden2 nodes to this output
162
+ for h2_id in range(hidden2_start, hidden2_start + hidden2_size):
163
+ weight = float(jrandom.normal(self.key) * 4.0) # Larger weights for outputs
164
+ self.connection_genes.append(ConnectionGene(
165
+ source=h2_id,
166
+ target=node_id,
167
+ weight=weight,
168
+ enabled=True,
169
+ innovation=self.innovation_number
170
+ ))
171
+ self.innovation_number += 1
172
+
173
+ def mutate(self, config: Dict):
174
+ """Mutate the genome by modifying weights and network structure.
175
+
176
+ Args:
177
+ config: Dictionary containing mutation parameters:
178
+ - weight_mutation_rate: Probability of mutating each weight
179
+ - weight_mutation_power: Standard deviation for weight mutations
180
+ - add_node_rate: Probability of adding a new node
181
+ - add_connection_rate: Probability of adding a new connection
182
+ """
183
+ # Mutate connection weights
184
+ for conn in self.connection_genes:
185
+ if jrandom.uniform(self.key) < config['weight_mutation_rate']:
186
+ # Get new random key
187
+ self.key, subkey = jrandom.split(self.key)
188
+ # Add random value from normal distribution
189
+ conn.weight += float(jrandom.normal(subkey) * config['weight_mutation_power'])
190
+
191
+ # Add new nodes (disabled for now since we're using fixed topology)
192
+ if config['add_node_rate'] > 0:
193
+ if jrandom.uniform(self.key) < config['add_node_rate']:
194
+ self._add_node()
195
+
196
+ # Add new connections (disabled for now)
197
+ if config['add_connection_rate'] > 0:
198
+ if jrandom.uniform(self.key) < config['add_connection_rate']:
199
+ self._add_connection()
200
+
201
+ def _add_node(self):
202
+ """Add a new node by splitting an existing connection."""
203
+ if not self.connection_genes:
204
+ return
205
+
206
+ # Choose a random connection to split
207
+ conn_to_split = np.random.choice(self.connection_genes)
208
+ conn_to_split.enabled = False
209
+
210
+ # Create new node
211
+ new_node_id = max(self.node_genes.keys()) + 1
212
+ self.node_genes[new_node_id] = NodeGene(
213
+ node_id=new_node_id,
214
+ node_type='hidden',
215
+ activation='relu'
216
+ )
217
+
218
+ # Create two new connections
219
+ self.connection_genes.extend([
220
+ ConnectionGene(
221
+ source=conn_to_split.source,
222
+ target=new_node_id,
223
+ weight=1.0,
224
+ enabled=True,
225
+ innovation=self.innovation_number
226
+ ),
227
+ ConnectionGene(
228
+ source=new_node_id,
229
+ target=conn_to_split.target,
230
+ weight=conn_to_split.weight,
231
+ enabled=True,
232
+ innovation=self.innovation_number + 1
233
+ )
234
+ ])
235
+ self.innovation_number += 2
236
+
237
+ def _add_connection(self):
238
+ """Add a new connection between two unconnected nodes."""
239
+ # Get list of all possible connections
240
+ existing_connections = {(c.source, c.target) for c in self.connection_genes}
241
+ possible_connections = []
242
+
243
+ for source in self.node_genes:
244
+ for target in self.node_genes:
245
+ # Skip if connection already exists
246
+ if (source, target) in existing_connections:
247
+ continue
248
+
249
+ # Skip if would create cycle (except recurrent)
250
+ if self.node_genes[source].node_type != 'recurrent' and \
251
+ self.would_create_cycle(source, target):
252
+ continue
253
+
254
+ possible_connections.append((source, target))
255
+
256
+ if possible_connections:
257
+ # Choose random connection
258
+ source, target = random.choice(possible_connections)
259
+
260
+ # Create new connection
261
+ weight = float(jrandom.normal(self.key) * 1.0)
262
+ self.connection_genes.append(ConnectionGene(
263
+ source=source,
264
+ target=target,
265
+ weight=weight,
266
+ enabled=True,
267
+ innovation=self.innovation_number
268
+ ))
269
+ self.innovation_number += 1
270
+
271
+ def would_create_cycle(self, source: int, target: int) -> bool:
272
+ """Check if adding connection would create cycle in network.
273
+
274
+ Args:
275
+ source: Source node ID
276
+ target: Target node ID
277
+
278
+ Returns:
279
+ True if connection would create cycle, False otherwise
280
+ """
281
+ # Skip cycle detection for recurrent connections
282
+ if self.node_genes[source].node_type == 'recurrent' or \
283
+ self.node_genes[target].node_type == 'recurrent':
284
+ return False
285
+
286
+ # Do depth-first search from target to see if we can reach source
287
+ visited = set()
288
+
289
+ def dfs(node: int) -> bool:
290
+ if node == source:
291
+ return True
292
+ if node in visited:
293
+ return False
294
+
295
+ visited.add(node)
296
+ for conn in self.connection_genes:
297
+ if conn.source == node and conn.enabled:
298
+ if dfs(conn.target):
299
+ return True
300
+ return False
301
+
302
+ return dfs(target)
303
+
304
+ def add_node_between(self, source: int, target: int):
305
+ """Add a new node between two nodes, splitting an existing connection.
306
+
307
+ Args:
308
+ source: Source node ID
309
+ target: Target node ID
310
+ """
311
+ # Find and disable the existing connection
312
+ for conn in self.connection_genes:
313
+ if conn.source == source and conn.target == target and conn.enabled:
314
+ conn.enabled = False
315
+
316
+ # Create new node
317
+ new_id = max(self.node_genes.keys()) + 1
318
+ self.node_genes[new_id] = NodeGene(
319
+ node_id=new_id,
320
+ node_type='hidden',
321
+ activation='relu'
322
+ )
323
+
324
+ # Create two new connections
325
+ self.connection_genes.extend([
326
+ ConnectionGene(
327
+ source=source,
328
+ target=new_id,
329
+ weight=1.0,
330
+ enabled=True,
331
+ innovation=self.innovation_number
332
+ ),
333
+ ConnectionGene(
334
+ source=new_id,
335
+ target=target,
336
+ weight=conn.weight,
337
+ enabled=True,
338
+ innovation=self.innovation_number + 1
339
+ )
340
+ ])
341
+ self.innovation_number += 2
342
+ break
343
+
344
+ def add_connection(self, source: int, target: int, weight: Optional[float] = None) -> bool:
345
+ """Add a new connection between two nodes.
346
+
347
+ Args:
348
+ source: Source node ID
349
+ target: Target node ID
350
+ weight: Optional connection weight. If None, a random weight is generated.
351
+
352
+ Returns:
353
+ True if connection was added, False if invalid or already exists
354
+ """
355
+ # Check if connection already exists
356
+ if any(c.source == source and c.target == target for c in self.connection_genes):
357
+ return False
358
+
359
+ # Validate nodes exist
360
+ if source not in self.node_genes or target not in self.node_genes:
361
+ return False
362
+
363
+ # Ensure feed-forward (no cycles)
364
+ if source >= target: # Simple way to ensure feed-forward
365
+ return False
366
+
367
+ # Generate random weight if not provided
368
+ if weight is None:
369
+ weight = float(jrandom.normal(self.key) * 1.0)
370
+
371
+ # Add new connection
372
+ self.connection_genes.append(ConnectionGene(
373
+ source=source,
374
+ target=target,
375
+ weight=weight,
376
+ enabled=True,
377
+ innovation=self.innovation_number
378
+ ))
379
+ self.innovation_number += 1
380
+ return True
381
+
382
+ def crossover(self, other: 'Genome', key: jnp.ndarray) -> 'Genome':
383
+ """Perform crossover between two genomes.
384
+
385
+ Args:
386
+ other: Other parent genome
387
+ key: JAX PRNG key
388
+
389
+ Returns:
390
+ Child genome
391
+ """
392
+ # Create child genome
393
+ child = Genome(self.input_size, self.output_size)
394
+
395
+ # Inherit node genes
396
+ for node_id in self.node_genes:
397
+ if node_id in other.node_genes:
398
+ # Inherit randomly from either parent
399
+ if jrandom.uniform(key) < 0.5:
400
+ child.node_genes[node_id] = self.node_genes[node_id]
401
+ else:
402
+ child.node_genes[node_id] = other.node_genes[node_id]
403
+ else:
404
+ # Inherit from fitter parent
405
+ child.node_genes[node_id] = self.node_genes[node_id]
406
+
407
+ # Inherit connection genes
408
+ for conn in self.connection_genes:
409
+ if conn.innovation in [c.innovation for c in other.connection_genes]:
410
+ # Inherit randomly from either parent
411
+ other_conn = next(c for c in other.connection_genes if c.innovation == conn.innovation)
412
+ if jrandom.uniform(key) < 0.5:
413
+ child.connection_genes.append(ConnectionGene(
414
+ source=conn.source,
415
+ target=conn.target,
416
+ weight=conn.weight,
417
+ enabled=conn.enabled,
418
+ innovation=conn.innovation
419
+ ))
420
+ else:
421
+ child.connection_genes.append(ConnectionGene(
422
+ source=other_conn.source,
423
+ target=other_conn.target,
424
+ weight=other_conn.weight,
425
+ enabled=other_conn.enabled,
426
+ innovation=other_conn.innovation
427
+ ))
428
+ else:
429
+ # Inherit from fitter parent
430
+ child.connection_genes.append(ConnectionGene(
431
+ source=conn.source,
432
+ target=conn.target,
433
+ weight=conn.weight,
434
+ enabled=conn.enabled,
435
+ innovation=conn.innovation
436
+ ))
437
+
438
+ return child
439
+
440
+ def clone(self) -> 'Genome':
441
+ """Create a copy of this genome.
442
+
443
+ Returns:
444
+ Copy of genome
445
+ """
446
+ clone = Genome(self.input_size, self.output_size)
447
+ clone.node_genes = self.node_genes.copy()
448
+ clone.connection_genes = [ConnectionGene(**conn.__dict__) for conn in self.connection_genes]
449
+ return clone
450
+
451
+ @property
452
+ def n_nodes(self) -> int:
453
+ """Get total number of nodes in the genome."""
454
+ return len(self.node_genes)