File size: 367 Bytes
7f3c2df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import abc


class Policy(abc.ABC):
    """Abstract interface for a policy that turns observations into actions.

    Concrete subclasses must implement both :meth:`get_action` and
    :meth:`eval`; until they do, the class cannot be instantiated.
    """

    def __init__(self, device, *args, **kwargs) -> None:
        """Remember the ``device`` this policy is associated with.

        Args:
            device: Stored unchanged on ``self.device``. Its type is not
                constrained here (NOTE(review): presumably a device handle or
                name used by subclasses — confirm against implementations).
            *args: Accepted so subclasses can forward extra positional
                arguments; ignored by this base class.
            **kwargs: Accepted so subclasses can forward extra keyword
                arguments; ignored by this base class.
        """
        # super().__init__() is deliberately not invoked here, preserving the
        # original construction behavior for subclasses that use mixins.
        self.device = device

    @abc.abstractmethod
    def get_action(self, obs_dict, **kwargs):
        """Predict an action for the given observation.

        Args:
            obs_dict: The observation to act on (NOTE(review): presumably a
                mapping of observation components — verify with callers).
            **kwargs: Implementation-specific options.

        Returns:
            Whatever action representation the concrete policy produces. The
            base implementation returns ``None``.
        """
        return None

    @abc.abstractmethod
    def eval(self):
        """Switch the policy into evaluation mode.

        The base implementation does nothing and returns ``None``.
        """
        return None