Spaces:
Build error
Build error
| "Client-server interface custom implementation for seizure detection models." | |
from pathlib import Path

from common import SEIZURE_DETECTION_MODEL_PATH
from concrete import fhe
from seizure_detection import SeizureDetector
class FHEServer:
    """Server interface to run a FHE circuit for seizure detection."""

    def __init__(self, model_path):
        """Initialize the FHE interface by loading the server circuit.

        Args:
            model_path (Union[str, Path]): The path to the directory where the circuit is saved.
                The file "server.zip" is loaded from this directory.
        """
        # Normalize to Path so plain strings work too; the "/" join below requires a Path.
        self.model_path = Path(model_path)

        # Load the FHE circuit
        self.server = fhe.Server.load(self.model_path / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run seizure detection on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The serialized encrypted output of the circuit, to be decrypted by the client.
        """
        # Rebuild the Concrete objects from the bytes received from the client
        encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the seizure detection circuit on the encrypted input
        encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)

        # Serialize the encrypted output so it can be sent back to the client
        return encrypted_output.serialize()
class FHEDev:
    """Development interface to export the compiled seizure detection model."""

    def __init__(self, seizure_detector, model_path):
        """Initialize the FHE interface.

        Args:
            seizure_detector (SeizureDetector): The seizure detection model to use in the FHE interface.
            model_path (Union[str, Path]): The path to the directory where the circuit is saved.
                The directory (and any missing parents) is created if it does not exist.
        """
        self.seizure_detector = seizure_detector

        # Normalize to Path: a plain str has neither `mkdir` nor the "/" join used in `save`.
        self.model_path = Path(model_path)
        self.model_path.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes "server.zip" and "client.zip" into `self.model_path`.

        Raises:
            AssertionError: If the seizure detector's `fhe_circuit` is None (model not compiled).
        """
        assert self.seizure_detector.fhe_circuit is not None, (
            "The model must be compiled before saving it."
        )
        fhe_circuit = self.seizure_detector.fhe_circuit

        # Save the circuit for the server, using the via_mlir in order to handle cross-platform
        # execution
        fhe_circuit.server.save(self.model_path / "server.zip", via_mlir=True)

        # Save the circuit for the client
        fhe_circuit.client.save(self.model_path / "client.zip")
class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a SeizureDetector."""

    def __init__(self, key_dir=None, model_path=None):
        """Initialize the FHE interface.

        Args:
            key_dir (Optional[Union[str, Path]]): The path to the directory where the keys are
                stored. Defaults to None.
            model_path (Optional[Union[str, Path]]): The path to the directory where the circuit
                is saved. Defaults to None, which uses SEIZURE_DETECTION_MODEL_PATH.

        Raises:
            AssertionError: If the model directory does not exist.
        """
        if model_path is None:
            model_path = SEIZURE_DETECTION_MODEL_PATH

        # Normalize to Path so `.exists()` and the "/" join work for str inputs as well
        self.model_path = Path(model_path)
        self.key_dir = key_dir

        # If model_path does not exist raise
        assert self.model_path.exists(), f"{self.model_path} does not exist. Please specify a valid path."

        # Load the client; key_dir is forwarded to Concrete as the key cache directory
        self.client = fhe.Client.load(self.model_path / "client.zip", self.key_dir)

        # Instantiate the seizure detector (used only for post-processing decrypted outputs)
        self.seizure_detector = SeizureDetector()

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The serialized evaluation keys, to be sent to the server.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize the input image in the clear.

        No pre-processing is applied; the image is encrypted as given.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        return encrypted_image.serialize()

    def deserialize_decrypt_post_process(self, serialized_encrypted_output):
        """Deserialize, decrypt and post-process the output in the clear.

        Args:
            serialized_encrypted_output (bytes): The serialized and encrypted output.

        Returns:
            The result of `SeizureDetector.post_processing` on the decrypted output
            (presumably a boolean seizure flag; confirm against SeizureDetector).
        """
        # Deserialize the encrypted output
        encrypted_output = fhe.Value.deserialize(serialized_encrypted_output)

        # Decrypt the output
        output = self.client.decrypt(encrypted_output)

        # Post-process the decrypted output with the seizure detector
        return self.seizure_detector.post_processing(output)