Spaces:
Sleeping
Sleeping
import secrets
import subprocess
import sys
import time
from pathlib import Path
from typing import List

import numpy as np
from concrete.ml.deployment import FHEModelClient
from concrete.ml.deployment import FHEModelServer
from fastapi import FastAPI, File, Form, UploadFile, Body
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel

from utils import (
    CLIENT_DIR,
    CURRENT_DIR,
    DEPLOYMENT_DIR,
    SERVER_DIR,
    INPUT_BROWSER_LIMIT,
    KEYS_DIR,
    SERVER_URL,
    TARGET_COLUMNS,
    TRAINING_FILENAME,
    clean_directory,
    get_disease_name,
    load_data,
    pretty_print,
)
# Load the FHE server
# NOTE(review): the module-level server below is disabled. Any handler that
# needs an FHEModelServer must construct it itself; restoring this line would
# instead construct the server from DEPLOYMENT_DIR at import time.
# FHE_SERVER = FHEModelServer(DEPLOYMENT_DIR)
class Symptoms(BaseModel):
    """Request body for ``generate_keys``: the symptoms submitted by the user."""

    # List of symptom labels; generate_keys only checks that it is non-empty.
    user_symptoms: List[str]
# FastAPI application instance.
# NOTE(review): no route decorators (e.g. @app.get / @app.post) are visible on
# the handler functions in this copy of the file; confirm routes are registered.
app = FastAPI()
def greet_json():
    """Return the static greeting mapping ``{"Hello": "World!"}``."""
    greeting = dict(Hello="World!")
    return greeting
def root():
    """Serve the welcome message of the disease-prediction API.

    Returns:
        dict: A single-entry mapping whose ``"message"`` key holds the
        welcome text.
    """
    welcome_text = "Welcome to your disease prediction with FHE!"
    return dict(message=welcome_text)
def send_input(
    user_id: str = Form(),
    files: List[UploadFile] = File(),
):
    """Store a client's encrypted symptoms and evaluation key on the server.

    Args:
        user_id: Client identifier, used as the filename prefix inside
            ``SERVER_DIR``. It must be a bare name with no path separators.
        files: Uploaded files in this order: ``files[0]`` is the encrypted
            input and ``files[1]`` is the serialized evaluation key.

    Returns:
        ``None`` on success. A 400 ``JSONResponse`` when ``user_id`` is not a
        bare file-name component or fewer than two files were uploaded.
    """
    print("\nSend the data to the server ............\n")

    # user_id comes straight from the form. Reject anything that would escape
    # SERVER_DIR once joined, such as "../x", "a/b" or an absolute path.
    if not user_id or Path(user_id).name != user_id:
        return JSONResponse(status_code=400, content={"error": "Invalid user_id."})
    if len(files) < 2:
        return JSONResponse(
            status_code=400,
            content={"error": "Expected two files: encrypted input and evaluation key."},
        )

    # NOTE: "_valuation_key" (sic) is also the suffix run_fhe reads back.
    # Keep both in sync rather than fixing the spelling in one place only.
    evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
    encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"

    # Read both uploads before touching the destination files, so a failing
    # upload read does not leave truncated files behind.
    encrypted_input_bytes = files[0].file.read()
    evaluation_key_bytes = files[1].file.read()

    encrypted_input_path.write_bytes(encrypted_input_bytes)
    evaluation_key_path.write_bytes(evaluation_key_bytes)
def run_fhe(
    user_id: str = Form(),
):
    """Run the FHE inference on the encrypted input previously sent by a client.

    Reads ``<user_id>_encrypted_input`` and ``<user_id>_valuation_key`` from
    ``SERVER_DIR``, evaluates the model server on them, and writes the result
    to ``<user_id>_encrypted_output`` in the same directory.

    Args:
        user_id: Client identifier, used as the filename prefix in ``SERVER_DIR``.

    Returns:
        JSONResponse: The FHE execution time in seconds, rounded to 2 decimals.
        A 400 error for an invalid ``user_id``, or a 404 error when the
        client's input files are missing.
    """
    print("\nRun in FHE in the server ............\n")

    # Reject ids that would resolve outside SERVER_DIR once joined.
    if not user_id or Path(user_id).name != user_id:
        return JSONResponse(status_code=400, content={"error": "Invalid user_id."})

    evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
    encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"

    # Read the files (encrypted symptoms + evaluation key).
    try:
        encrypted_input = encrypted_input_path.read_bytes()
        evaluation_key = evaluation_key_path.read_bytes()
    except FileNotFoundError:
        return JSONResponse(
            status_code=404,
            content={"error": "Input files not found; send the input first."},
        )

    # The module-level FHE_SERVER is commented out in this file, so a bare
    # reference raised NameError. Reuse it if it is defined; otherwise build
    # the server from the deployment directory for this call.
    fhe_server = globals().get("FHE_SERVER")
    if fhe_server is None:
        fhe_server = FHEModelServer(DEPLOYMENT_DIR)

    # Time only the FHE execution itself, not the file I/O or server loading.
    start = time.perf_counter()
    encrypted_output = fhe_server.run(encrypted_input, evaluation_key)
    assert isinstance(encrypted_output, bytes)
    fhe_execution_time = round(time.perf_counter() - start, 2)

    encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
    encrypted_output_path.write_bytes(encrypted_output)

    return JSONResponse(content=fhe_execution_time)
def get_output(user_id: str = Form()):
    """Return the encrypted inference result stored for a client.

    Args:
        user_id: Client identifier, used as the filename prefix in ``SERVER_DIR``.

    Returns:
        Response: The raw bytes of ``<user_id>_encrypted_output``. A 400
        ``JSONResponse`` for an invalid ``user_id``, or a 404 one when no
        output has been produced for this id yet.
    """
    print("\nGet the output from the server ............\n")

    # Reject ids that would resolve outside SERVER_DIR once joined.
    if not user_id or Path(user_id).name != user_id:
        return JSONResponse(status_code=400, content={"error": "Invalid user_id."})

    encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
    try:
        encrypted_output = encrypted_output_path.read_bytes()
    except FileNotFoundError:
        return JSONResponse(
            status_code=404,
            content={"error": "Encrypted output not found; run the FHE inference first."},
        )

    # NOTE(review): this fixed 1 s delay is kept from the original, but its
    # purpose is not visible here (possibly UI pacing). It blocks the calling
    # thread, so confirm whether it is still needed.
    time.sleep(1)

    # The payload is opaque ciphertext, so label it as binary.
    return Response(encrypted_output, media_type="application/octet-stream")
def generate_keys(symptoms: Symptoms):
    """Create a new user id and generate that user's FHE keys.

    The submitted symptoms are only checked for emptiness here; they do not
    influence the generated keys.

    Args:
        symptoms: Request body with the user's symptom list.

    Returns:
        JSONResponse: On success, contains ``user_id``, ``evaluation_key`` and
        ``evaluation_key_size``. ``evaluation_key`` is the serialized
        evaluation key as hex, truncated to ``INPUT_BROWSER_LIMIT`` characters,
        and ``evaluation_key_size`` is its size in MB. A 400 error is returned
        when no symptom was submitted.
    """
    # Reject empty submissions first, so an invalid request does not trigger
    # the directory cleanup below.
    if not symptoms.user_symptoms:
        return JSONResponse(
            status_code=400, content={"error": "Veuillez soumettre vos symptômes en premier."}
        )

    clean_directory()

    # The id names this user's key directory and acts as the handle for their
    # data, so draw it from a CSPRNG. randbelow(2**32) keeps the original
    # [0, 2**32) range and always yields a plain Python int.
    # np.random.randint with high=2**32 can exceed the default integer dtype
    # on platforms where that dtype is 32-bit.
    user_id = secrets.randbelow(2**32)
    print(f"Votre ID utilisateur est : {user_id}....")

    key_dir = KEYS_DIR / f"{user_id}"
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=key_dir)
    client.load()

    # Create the private and evaluation keys on the client side.
    client.generate_private_and_evaluation_keys()

    # Fetch the serialized evaluation keys (the part the server needs).
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Persist the evaluation key in the user's key directory.
    evaluation_key_path = key_dir / "evaluation_key"
    evaluation_key_path.parent.mkdir(parents=True, exist_ok=True)
    evaluation_key_path.write_bytes(serialized_evaluation_keys)

    # Only a prefix of the hex dump is returned, to keep the browser payload small.
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
    return JSONResponse(
        content={
            "user_id": user_id,
            "evaluation_key": serialized_evaluation_keys_shorten_hex,
            "evaluation_key_size": f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
        }
    )
def run_dev_script():
    """Execute ``dev.py`` (located next to this file) to generate deployment files.

    Returns:
        JSONResponse: ``{"message", "output"}`` with the script's stdout on
        success. A 500 response ``{"error", "details"}`` is returned when the
        script exits with a non-zero status or cannot be launched at all.
    """
    dev_script_path = Path(__file__).parent / "dev.py"
    try:
        # Use the interpreter running this server rather than whatever
        # "python" resolves to on PATH. That name may be missing, or may
        # point at a different environment.
        result = subprocess.run(
            [sys.executable, str(dev_script_path)],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # The script ran but failed: surface its stderr.
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to execute dev.py", "details": e.stderr},
        )
    except OSError as e:
        # The interpreter could not be started (e.g. FileNotFoundError).
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to execute dev.py", "details": str(e)},
        )

    return JSONResponse(
        content={"message": "dev.py executed successfully!", "output": result.stdout}
    )