import torch
import torch.nn as nn
import torch.nn.functional as F

from agent.dreamer import DreamerAgent, stop_gradient
import agent.dreamer_utils as common

class Disagreement(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None):
        super().__init__()
        if pred_dim is None:
            pred_dim = obs_dim
        # Ensemble of small MLPs, each predicting the next observation (or embedding)
        # from the concatenated current observation and action.
        self.ensemble = nn.ModuleList([
            nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim),
                          nn.ReLU(), nn.Linear(hidden_dim, pred_dim))
            for _ in range(n_models)
        ])
    def forward(self, obs, action, next_obs):
        assert obs.shape[0] == next_obs.shape[0]
        assert obs.shape[0] == action.shape[0]
        # Per-model L2 prediction error; used as the training loss for the ensemble.
        errors = []
        for model in self.ensemble:
            next_obs_hat = model(torch.cat([obs, action], dim=-1))
            model_error = torch.norm(next_obs - next_obs_hat,
                                     dim=-1, p=2, keepdim=True)
            errors.append(model_error)
        return torch.cat(errors, dim=1)  # (batch, n_models)
    def get_disagreement(self, obs, action):
        assert obs.shape[0] == action.shape[0]
        # Variance of the ensemble predictions, averaged over feature dimensions,
        # is the disagreement signal used as the intrinsic reward.
        preds = []
        for model in self.ensemble:
            next_obs_hat = model(torch.cat([obs, action], dim=-1))
            preds.append(next_obs_hat)
        preds = torch.stack(preds, dim=0)
        return torch.var(preds, dim=0).mean(dim=-1)  # (batch,)
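
# Shape sketch (added, illustrative only): with a batch of B flattened transitions
# and the default n_models=5,
#   Disagreement(...)(obs, action, next_obs)        -> (B, 5) per-model L2 errors
#   Disagreement(...).get_disagreement(obs, action) -> (B,)   prediction variance
# A runnable smoke test under these assumptions is sketched at the end of the file.
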
class Plan2Explore(DreamerAgent):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The ensemble takes the world-model features as input and predicts the
        # encoder embedding of the next observation.
        in_dim = self.wm.inp_size
        pred_dim = self.wm.embed_dim
        self.hidden_dim = pred_dim
        self.reward_free = True
        self.disagreement = Disagreement(in_dim, self.act_dim, self.hidden_dim,
                                         pred_dim=pred_dim).to(self.device)
        # optimizers
        self.disagreement_opt = common.Optimizer('disagreement', self.disagreement.parameters(),
                                                 **self.cfg.model_opt, use_amp=self._use_amp)
        self.disagreement.train()
        # Gradients are disabled globally and re-enabled per module during updates
        # (see common.RequiresGrad in update()).
        self.requires_grad_(requires_grad=False)
    def update_disagreement(self, obs, action, next_obs, step):
        metrics = dict()
        # Train every ensemble member to regress the next embedding; the mean
        # L2 error over models and batch is the loss.
        error = self.disagreement(obs, action, next_obs)
        loss = error.mean()
        metrics.update(self.disagreement_opt(loss, self.disagreement.parameters()))
        metrics['disagreement_loss'] = loss.item()
        return metrics
    def compute_intr_reward(self, seq):
        # Disagreement on the transition (feat_t, action_{t+1}) is credited to
        # step t+1 of the imagined sequence; the first step receives zero reward.
        obs, action = seq['feat'][:-1], stop_gradient(seq['action'][1:])
        intr_rew = torch.zeros(list(seq['action'].shape[:-1]) + [1], device=self.device)
        if len(action.shape) > 2:
            B, T, _ = action.shape
            obs = obs.reshape(B * T, -1)
            action = action.reshape(B * T, -1)
            reward = self.disagreement.get_disagreement(obs, action).reshape(B, T, 1)
        else:
            reward = self.disagreement.get_disagreement(obs, action).unsqueeze(-1)
        intr_rew[1:] = reward
        return intr_rew
    def update(self, data, step):
        metrics = {}
        B, T, _ = data['action'].shape
        # Update the world model first, then learn behaviors from its posterior states.
        state, outputs, mets = self.wm.update(data, state=None)
        metrics.update(mets)
        start = outputs['post']
        start = {k: stop_gradient(v) for k, v in start.items()}
        if self.reward_free:
            # Train the ensemble on (feat_t, action_{t+1}) -> embed_{t+1} pairs,
            # then update the acting behavior in imagination with the
            # disagreement-based intrinsic reward.
            T = T - 1
            inp = stop_gradient(outputs['feat'][:, :-1]).reshape(B * T, -1)
            action = data['action'][:, 1:].reshape(B * T, -1)
            out = stop_gradient(outputs['embed'][:, 1:]).reshape(B * T, -1)
            with common.RequiresGrad(self.disagreement):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    metrics.update(
                        self.update_disagreement(inp, action, out, step))
            metrics.update(self._acting_behavior.update(
                self.wm, start, data['is_terminal'], reward_fn=self.compute_intr_reward))
        else:
            # Extrinsic phase: use the world model's learned reward head instead.
            reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean
            metrics.update(self._acting_behavior.update(
                self.wm, start, data['is_terminal'], reward_fn))
        return state, metrics
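
# --- Usage sketch (added, not part of the original file) ---
# Minimal smoke test for the Disagreement ensemble on random tensors. The
# dimensions below are illustrative assumptions, not values taken from the agent
# configuration, and running this file still requires the agent.* modules
# imported at the top to be importable.
if __name__ == '__main__':
    _B, _feat_dim, _act_dim, _embed_dim = 16, 32, 4, 24
    _ens = Disagreement(obs_dim=_feat_dim, action_dim=_act_dim,
                        hidden_dim=_embed_dim, pred_dim=_embed_dim)
    _obs = torch.randn(_B, _feat_dim)
    _act = torch.randn(_B, _act_dim)
    _next = torch.randn(_B, _embed_dim)
    print(_ens(_obs, _act, _next).shape)            # torch.Size([16, 5])
    print(_ens.get_disagreement(_obs, _act).shape)  # torch.Size([16])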