Spaces:
Paused
Paused
File size: 367 Bytes
7f3c2df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import abc
class Policy(abc.ABC):
    """Abstract base class for policies that map observations to actions.

    Concrete subclasses must implement :meth:`get_action` and :meth:`eval`;
    because both are declared with ``abc.abstractmethod``, ``Policy`` itself
    (or a subclass missing either method) cannot be instantiated.
    """

    def __init__(self, device, *args, **kwargs) -> None:
        """Store the target device for this policy.

        Args:
            device: Device handle the policy should run on, stored unchanged
                as ``self.device``. (Presumably a string such as ``"cpu"`` or
                a framework device object -- confirm against callers.)
            *args: Accepted for subclass/constructor compatibility; ignored.
            **kwargs: Accepted for subclass/constructor compatibility; ignored.
        """
        self.device = device

    @abc.abstractmethod
    def get_action(self, obs_dict, **kwargs):
        """Predict an action based on the input observation.

        Args:
            obs_dict: The observation to act on. Presumably a mapping of
                observation names to values -- verify against implementations.
            **kwargs: Implementation-specific options.

        Returns:
            The predicted action; its type is defined by the subclass.
        """

    @abc.abstractmethod
    def eval(self):
        """Set the policy to evaluation mode."""