Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from kinetix.models.actor_critic import ( | |
| ActorCriticPixelsRNN, | |
| ActorCriticSymbolicRNN, | |
| ) | |
| from kinetix.models.transformer_model import ActorCriticTransformer | |
def make_network_from_config(env, env_params, config, network_kws=None):
    """Build the actor-critic network described by ``config`` for ``env``.

    Both the action mode and the network family are inferred from substrings
    of ``config["env_name"]``:

    * action mode (checked in this order, so "MultiDiscrete" wins over
      "Discrete"): "MultiDiscrete" -> "multi_discrete", "Discrete" ->
      "discrete", "Continuous" -> "continuous", "Hybrid" -> "hybrid".
    * network: "Entity" -> ``ActorCriticTransformer``; otherwise "Pixels" ->
      ``ActorCriticPixelsRNN``; otherwise "Symbolic" or "Blind" ->
      ``ActorCriticSymbolicRNN``.

    Args:
        env: Environment exposing ``action_space(env_params)``.
        env_params: Parameters forwarded to ``env.action_space``.
        config: Mapping read for "env_name", "fc_layer_width",
            "fc_layer_depth" and "activation"; optionally "recurrent_model"
            (defaults to True). Entity networks additionally read
            "num_heads", "transformer_depth", "transformer_size",
            "transformer_encoder_size", "aggregate_mode" and
            "full_attention_mask". When ``network_kws`` lacks
            "multi_discrete_number_of_dims_per_distribution", the
            "static_env_params" entry ("num_motor_bindings",
            "num_thruster_bindings") is read as well.
        network_kws: Optional extra keyword arguments for the network
            constructor. The dict passed in is never modified; missing
            defaults are filled into a copy.

    Returns:
        The constructed network instance.

    Raises:
        ValueError: If the action mode or the network type cannot be inferred
            from ``config["env_name"]``.
    """
    env_name = config["env_name"]

    if "MultiDiscrete" in env_name:
        action_mode = "multi_discrete"
    elif "Discrete" in env_name:
        action_mode = "discrete"
    elif "Continuous" in env_name:
        action_mode = "continuous"
    elif "Hybrid" in env_name:
        action_mode = "hybrid"
    else:
        raise ValueError(f"Unknown action mode for {env_name}")

    # Pick the network family up front so an unrecognised env name fails with
    # a clear error rather than an UnboundLocalError at construction time.
    use_transformer = "Entity" in env_name
    if use_transformer:
        network_cls = None
    elif "Pixels" in env_name:
        network_cls = ActorCriticPixelsRNN
    elif "Symbolic" in env_name or "Blind" in env_name:
        network_cls = ActorCriticSymbolicRNN
    else:
        raise ValueError(f"Unknown observation type for {env_name}")

    # Work on a private copy. The former `network_kws={}` default was shared
    # across calls and mutated below, so values leaked between invocations.
    network_kws = {} if network_kws is None else dict(network_kws)

    action_space = env.action_space(env_params)
    # Continuous spaces are sized by their shape; every other mode reads `.n`.
    action_dim = action_space.shape[0] if action_mode == "continuous" else action_space.n

    network_kws.setdefault("hybrid_action_continuous_dim", action_dim)

    if "multi_discrete_number_of_dims_per_distribution" not in network_kws:
        static_env_params = config["static_env_params"]
        num_joint_bindings = static_env_params["num_motor_bindings"]
        num_thruster_bindings = static_env_params["num_thruster_bindings"]
        # 3 dims per motor-binding distribution, then 2 per thruster binding.
        network_kws["multi_discrete_number_of_dims_per_distribution"] = [3 for _ in range(num_joint_bindings)] + [
            2 for _ in range(num_thruster_bindings)
        ]

    # Always overrides any caller-supplied "recurrent" entry.
    network_kws["recurrent"] = config.get("recurrent_model", True)

    if use_transformer:
        return ActorCriticTransformer(
            action_dim=action_dim,
            fc_layer_width=config["fc_layer_width"],
            fc_layer_depth=config["fc_layer_depth"],
            action_mode=action_mode,
            num_heads=config["num_heads"],
            transformer_depth=config["transformer_depth"],
            transformer_size=config["transformer_size"],
            transformer_encoder_size=config["transformer_encoder_size"],
            aggregate_mode=config["aggregate_mode"],
            full_attention_mask=config["full_attention_mask"],
            activation=config["activation"],
            **network_kws,
        )

    return network_cls(
        action_dim,
        fc_layer_width=config["fc_layer_width"],
        fc_layer_depth=config["fc_layer_depth"],
        activation=config["activation"],
        action_mode=action_mode,
        **network_kws,
    )