"""Hugging Face model wrapper around an exported TensorFlow user SavedModel."""

import tensorflow as tf
from transformers import PreTrainedModel

from user_model.configuration import UserModelConfig

# Directory loaded when the config does not name one. This is the value that
# was previously hard-coded, so configs without ``saved_model_path`` are
# unaffected.
DEFAULT_SAVED_MODEL_PATH = "tf_retrieval_user_model"


class UserModel(PreTrainedModel):
    """Expose a TensorFlow SavedModel through the ``PreTrainedModel`` API.

    The wrapped object is whatever ``tf.saved_model.load`` returns for the
    configured directory. ``forward`` simply calls it.

    NOTE(review): ``transformers.PreTrainedModel`` is the PyTorch base class.
    The loaded TF object holds its own variables, which are not registered as
    torch parameters. Presumably ``save_pretrained`` / ``from_pretrained``
    will therefore not persist or restore the TF weights. Confirm before
    relying on them, or consider ``TFPreTrainedModel``.
    """

    config_class = UserModelConfig

    def __init__(self, config: UserModelConfig) -> None:
        """Build the wrapper and eagerly load the TensorFlow SavedModel.

        Args:
            config: Model configuration. If it has a truthy
                ``saved_model_path`` attribute, that directory is loaded.
                Otherwise ``DEFAULT_SAVED_MODEL_PATH`` is used.
        """
        super().__init__(config)
        # ``getattr`` keeps older configs (without the attribute) working;
        # ``or`` also falls back when the attribute exists but is None/empty.
        saved_model_path = (
            getattr(config, "saved_model_path", None) or DEFAULT_SAVED_MODEL_PATH
        )
        self.model = tf.saved_model.load(saved_model_path)

    def forward(self, user_id):
        """Run the loaded SavedModel on ``user_id`` and return its output.

        Args:
            user_id: Input passed straight to the SavedModel's ``__call__``.
                The expected type and shape are defined by the exported
                signature (presumably a TF tensor of user ids; TODO confirm).

        Returns:
            The SavedModel's output, unchanged.
        """
        return self.model(user_id)