|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Pushes to first target, waits, then pushes to second target.""" |
|
|
|
import .block_pushing.oracles.oriented_push_oracle as oriented_push_oracle_module |
|
import numpy as np |
|
from tf_agents.trajectories import policy_step |
|
from tf_agents.trajectories import time_step as ts |
|
from tf_agents.typing import types |
|
|
|
|
|
import pybullet |
|
|
|
|
|
class DiscontinuousOrientedPushOracle(oriented_push_oracle_module.OrientedPushOracle):
    """Pushes to first target, waits, then pushes to second target.

    Motion toward a target is delegated to the inherited
    `_get_action_for_block_target`. The oracle starts by pushing the block
    toward `"target"`. The first time the block comes within `goal_tolerance`
    of that target, it switches to `"target2"` for the rest of the episode.
    It then emits zero actions for `wait` steps, starting with the step on
    which the switch happens.
    """

    def __init__(self, env, goal_tolerance=0.04, wait=0):
        """Creates the oracle.

        Args:
          env: Environment, passed through to the parent oracle's constructor.
          goal_tolerance: Block-to-first-target distance below which the first
            target counts as reached.
          wait: Number of steps for which a zero action is emitted once the
            oracle switches to the second target.
        """
        super(DiscontinuousOrientedPushOracle, self).__init__(env)
        self._wait = wait
        self._goal_dist_tolerance = goal_tolerance
        # Initialize the per-episode state here as well, so `_action` does not
        # fail with AttributeError if it is called before a FIRST time step.
        self._reset_target_state()

    def _reset_target_state(self):
        """Restores the countdown and target-switching state to episode start."""
        self._countdown = 0
        self._current_target = "target"
        self._has_switched = False

    def reset(self):
        """Sets `phase` to "move_to_pre_block" and clears the target-switch state."""
        self.phase = "move_to_pre_block"
        # Clearing the switch state here keeps a direct `reset()` call from
        # leaving the oracle stuck on "target2" from a previous episode.
        self._reset_target_state()

    def _action(
        self, time_step: ts.TimeStep, policy_state: types.NestedArray
    ) -> policy_step.PolicyStep:
        """Computes the next xy action for the block-pushing task.

        Args:
          time_step: Current time step. Its observation must provide
            "block_translation" and "target_translation" arrays.
          policy_state: Unused by this oracle.

        Returns:
          A `PolicyStep` whose action is the xy delta as a float32 array. The
          delta is all zeros while the post-switch wait countdown is active.
        """
        if time_step.is_first():
            self.reset()

        # The first-target distance only matters until the switch has happened.
        if not self._has_switched:
            observation = time_step.observation
            first_target_dist = np.linalg.norm(
                observation["block_translation"] - observation["target_translation"]
            )
            if first_target_dist < self._goal_dist_tolerance:
                # First target reached: go for the second target from now on,
                # pausing for `self._wait` steps first.
                self._has_switched = True
                self._current_target = "target2"
                self._countdown = self._wait

        # Always query the parent, even while paused, so that any state it
        # keeps (e.g. `phase`) keeps advancing exactly as before.
        xy_delta = self._get_action_for_block_target(
            time_step, block="block", target=self._current_target
        )

        if self._countdown > 0:
            # Still waiting after the switch: emit a zero action of the same shape.
            xy_delta = np.zeros_like(xy_delta)
            self._countdown -= 1

        return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32))
|
|