File size: 3,005 Bytes
6340080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import torch


def get_agent_id_feature(agent_id, agent_num):
    """Return a one-hot encoding of ``agent_id`` among ``agent_num`` agents.

    Args:
        agent_id: Index of the agent; used directly to index the tensor.
        agent_num: Total number of agents, i.e. the length of the encoding.

    Returns:
        A 1-D tensor of length ``agent_num`` that is zero everywhere except
        for a single 1 at position ``agent_id``.
    """
    one_hot = torch.zeros(agent_num)
    # Flip on the slot that belongs to this agent.
    one_hot[agent_id] = 1
    return one_hot


def get_movement_feature():
    """Return a placeholder movement feature.

    A real environment would supply this; here it is simply a random binary
    vector.

    Returns:
        A 1-D integer tensor of length 8 whose entries are each 0 or 1.
    """
    # ``high`` is exclusive, so the sampled values are 0 or 1.
    return torch.randint(low=0, high=2, size=(8,))


def get_own_feature():
    """Return a placeholder for the agent's own feature.

    A real environment would supply this; here it is random noise.

    Returns:
        A 1-D tensor of 10 samples drawn from a standard normal distribution.
    """
    own_feature = torch.randn(10)
    return own_feature


def get_ally_visible_feature():
    """Return the visible feature of a single ally.

    Visibility is simulated: the ally is treated as visible when a uniform
    NumPy draw exceeds 0.5. A visible ally yields a random feature, an
    invisible one yields all zeros.

    Returns:
        A 1-D tensor of length 4.
    """
    is_visible = np.random.random() > 0.5
    # Only sample from torch's generator when the ally is actually visible.
    return torch.randn(4) if is_visible else torch.zeros(4)


def get_enemy_visible_feature():
    """Return the visible feature of a single enemy.

    Visibility is simulated: the enemy is treated as visible when a uniform
    NumPy draw exceeds 0.8. A visible enemy yields a random feature, an
    invisible one yields all zeros.

    Returns:
        A 1-D tensor of length 4.
    """
    is_visible = np.random.random() > 0.8
    # Only sample from torch's generator when the enemy is actually visible.
    return torch.randn(4) if is_visible else torch.zeros(4)


def get_ind_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Build the independent global state of one agent.

    The independent state is taken to be the agent's own local observation
    (the usual choice for decentralized training): the 1-D concatenation,
    in this order, of

    * the one-hot agent id (length ``ally_agent_num``),
    * the movement feature (length 8),
    * the agent's own feature (length 10),
    * one visible feature (length 4) for each *other* ally,
    * one visible feature (length 4) for each enemy.

    Args:
        agent_id: Index of the agent among the allies.
        ally_agent_num: Number of allied agents, including this one.
        enemy_agent_num: Number of enemy agents.

    Returns:
        A 1-D float tensor of length
        ``ally_agent_num + 18 + 4 * (ally_agent_num - 1) + 4 * enemy_agent_num``.
    """
    own_feature = get_own_feature()
    parts = [
        get_agent_id_feature(agent_id, ally_agent_num),
        # The movement feature is an integer tensor; cast it so every part
        # shares one floating dtype before concatenation.
        get_movement_feature().to(own_feature.dtype),
        own_feature,
    ]
    # An agent does not observe itself, hence ``ally_agent_num - 1`` allies.
    # Extending a flat list (rather than concatenating sub-lists) also keeps
    # this working when there are no other allies or no enemies.
    parts.extend(get_ally_visible_feature() for _ in range(ally_agent_num - 1))
    parts.extend(get_enemy_visible_feature() for _ in range(enemy_agent_num))
    return torch.cat(parts)


def get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Return the environment-provided global state.

    In many multi-agent environments such as SMAC, the global state is a
    simplified combination of all agents' independent states, and its
    concrete form depends on the characteristics of the environment. For
    simplicity this is a random placeholder; none of the arguments are used.

    Args:
        agent_id: Unused.
        ally_agent_num: Unused.
        enemy_agent_num: Unused.

    Returns:
        A 1-D tensor of length 16: an 8-dim ally-center feature followed by
        an 8-dim enemy-center feature.
    """
    center_dim = 8
    # Sample the ally part first, then the enemy part.
    ally_center = torch.randn(center_dim)
    enemy_center = torch.randn(center_dim)
    return torch.cat((ally_center, enemy_center), dim=0)


def get_as_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Build the naive agent-specific global state of one agent.

    The agent-specific state is the 1-D concatenation of the agent's local
    features followed by the environment-provided global state, so that each
    agent sees the shared global information together with what is specific
    to itself. In order, the parts are

    * the one-hot agent id (length ``ally_agent_num``),
    * the movement feature (length 8),
    * the agent's own feature (length 10),
    * one visible feature (length 4) for each *other* ally,
    * one visible feature (length 4) for each enemy,
    * the result of ``get_ep_global_state`` for the same arguments.

    Args:
        agent_id: Index of the agent among the allies.
        ally_agent_num: Number of allied agents, including this one.
        enemy_agent_num: Number of enemy agents.

    Returns:
        A 1-D float tensor holding the concatenated features.
    """
    own_feature = get_own_feature()
    parts = [
        get_agent_id_feature(agent_id, ally_agent_num),
        # The movement feature is an integer tensor; cast it so every part
        # shares one floating dtype before concatenation.
        get_movement_feature().to(own_feature.dtype),
        own_feature,
    ]
    # An agent does not observe itself, hence ``ally_agent_num - 1`` allies.
    parts.extend(get_ally_visible_feature() for _ in range(ally_agent_num - 1))
    parts.extend(get_enemy_visible_feature() for _ in range(enemy_agent_num))
    # Shared, environment-provided part goes last.
    parts.append(get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num))
    return torch.cat(parts)


def test_global_state():
    """Check that every global-state builder returns a ``torch.Tensor``.

    Each builder is called once per ally agent, using 3 allies and 5 enemies.
    The builders are exercised in this order:

    * independent global state, usually used in decentralized training;
    * environment-provided global state, the same for all agents, used in
      centralized training;
    * naive agent-specific global state, specific to each agent, used in
      centralized training.
    """
    num_allies, num_enemies = 3, 5
    state_builders = (
        get_ind_global_state,
        get_ep_global_state,
        get_as_global_state,
    )
    for build_state in state_builders:
        for agent_id in range(num_allies):
            state = build_state(agent_id, num_allies, num_enemies)
            assert isinstance(state, torch.Tensor)


# Run the global-state check when this file is executed as a script.
if __name__ == "__main__":
    test_global_state()