|
import numpy as np |
|
import torch |
|
|
|
|
|
def get_agent_id_feature(agent_id, agent_num):
    """Build a one-hot vector marking ``agent_id`` among ``agent_num`` agents.

    Args:
        agent_id: Zero-based index of the agent; must satisfy
            ``0 <= agent_id < agent_num``.
        agent_num: Total number of agents, i.e. the length of the result.

    Returns:
        A 1-D ``torch.Tensor`` of length ``agent_num`` that is all zeros
        except for a single ``1`` at position ``agent_id``.

    Raises:
        IndexError: If ``agent_id`` is outside ``[0, agent_num)``.
    """
    # Reject negative ids explicitly: plain tensor indexing would wrap them
    # around and silently mark a different agent.
    if not 0 <= agent_id < agent_num:
        raise IndexError(
            f"agent_id {agent_id} is out of range for {agent_num} agents"
        )
    agent_id_feature = torch.zeros(agent_num)
    agent_id_feature[agent_id] = 1
    return agent_id_feature
|
|
|
|
|
def get_movement_feature():
    """Return a randomly generated binary movement feature.

    Returns:
        A 1-D integer ``torch.Tensor`` of length 8 whose entries are each
        either 0 or 1, freshly sampled on every call.
    """
    # torch.randint's upper bound is exclusive, so (0, 2) yields only 0 or 1.
    # NOTE(review): the 8 slots presumably correspond to movement options;
    # confirm the intended meaning of each slot.
    return torch.randint(0, 2, (8,))
|
|
|
|
|
def get_own_feature():
    """Return a randomly generated feature vector for the agent itself.

    Returns:
        A 1-D floating-point ``torch.Tensor`` of length 10 sampled from a
        standard normal distribution on every call.
    """
    return torch.randn(10)
|
|
|
|
|
def get_ally_visible_feature():
    """Return a randomly generated ally-visibility feature vector.

    A NumPy uniform draw decides the branch: when it exceeds 0.5 the result
    is filled with standard-normal samples, otherwise it is all zeros.

    Returns:
        A 1-D ``torch.Tensor`` of length 4.
    """
    # NOTE(review): the all-zero branch presumably stands for an ally that is
    # not visible; confirm against the intended observation layout.
    if np.random.random() > 0.5:
        return torch.randn(4)
    return torch.zeros(4)
|
|
|
|
|
def get_enemy_visible_feature():
    """Return a randomly generated enemy-visibility feature vector.

    A NumPy uniform draw decides the branch: when it exceeds 0.8 the result
    is filled with standard-normal samples, otherwise it is all zeros.

    Returns:
        A 1-D ``torch.Tensor`` of length 4.
    """
    # NOTE(review): the all-zero branch presumably stands for an enemy that
    # is not visible; confirm against the intended observation layout.
    if np.random.random() > 0.8:
        return torch.randn(4)
    return torch.zeros(4)
|
|
|
|
|
def get_ind_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Unimplemented builder for the "ind" global state of one agent.

    Args:
        agent_id: Index of the agent the state is requested for.
        ally_agent_num: Number of allied agents.
        enemy_agent_num: Number of enemy agents.

    Raises:
        NotImplementedError: Always; the function body has not been written.
    """
    # NOTE(review): test_global_state expects a torch.Tensor from this
    # function, so that test fails here until an implementation is supplied.
    # "ind" presumably abbreviates "independent"; confirm before implementing.
    raise NotImplementedError("get_ind_global_state is not implemented yet")
|
|
|
|
|
def get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Return a randomly generated "ep" global state.

    The result is the concatenation of two independently sampled
    standard-normal vectors of length 8 (an ally part followed by an enemy
    part). None of the arguments influence the output; they are accepted so
    the signature matches the other global-state builders in this file.

    Args:
        agent_id: Index of the requesting agent (unused).
        ally_agent_num: Number of allied agents (unused).
        enemy_agent_num: Number of enemy agents (unused).

    Returns:
        A 1-D floating-point ``torch.Tensor`` of length 16.
    """
    ally_center_feature = torch.randn(8)
    enemy_center_feature = torch.randn(8)
    return torch.cat([ally_center_feature, enemy_center_feature])
|
|
|
|
|
def get_as_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Unimplemented builder for the "as" global state of one agent.

    Args:
        agent_id: Index of the agent the state is requested for.
        ally_agent_num: Number of allied agents.
        enemy_agent_num: Number of enemy agents.

    Raises:
        NotImplementedError: Always; the function body has not been written.
    """
    # NOTE(review): test_global_state expects a torch.Tensor from this
    # function, so that test fails here until an implementation is supplied.
    # "as" presumably abbreviates "agent-specific"; confirm before
    # implementing.
    raise NotImplementedError("get_as_global_state is not implemented yet")
|
|
|
|
|
def test_global_state():
    """Check that every global-state builder returns a ``torch.Tensor``.

    Each builder is called once per allied agent (3 allies, 5 enemies). The
    builders are exercised in the order ind -> ep -> as, finishing all agents
    for one builder before moving on to the next.

    Raises:
        AssertionError: If a builder returns something other than a tensor.
        NotImplementedError: Propagated from any builder that is still a
            stub.
    """
    ally_agent_num = 3
    enemy_agent_num = 5

    builders = (get_ind_global_state, get_ep_global_state, get_as_global_state)
    for build_state in builders:
        for agent_id in range(ally_agent_num):
            state = build_state(agent_id, ally_agent_num, enemy_agent_num)
            assert isinstance(state, torch.Tensor), (
                f"{build_state.__name__} returned {type(state).__name__}, "
                "expected torch.Tensor"
            )
|
|
|
|
|
if __name__ == "__main__":
    # Run the global-state check when this file is executed as a script.
    test_global_state()
|
|