Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, File, Form, UploadFile | |
from fastapi.responses import JSONResponse, Response | |
from concrete.ml.deployment import FHEModelServer | |
import numpy as np | |
from concrete.ml.deployment import FHEModelClient | |
import subprocess | |
from pathlib import Path | |
from utils import ( | |
CLIENT_DIR, | |
CURRENT_DIR, | |
DEPLOYMENT_DIR, | |
SERVER_DIR, | |
INPUT_BROWSER_LIMIT, | |
KEYS_DIR, | |
SERVER_URL, | |
TARGET_COLUMNS, | |
TRAINING_FILENAME, | |
clean_directory, | |
get_disease_name, | |
load_data, | |
pretty_print, | |
) | |
import time | |
from typing import List | |
# Module-level singletons shared by the request handlers in this file.

# The FastAPI application object.
# NOTE(review): no route decorators referencing `app` are visible on the
# handler functions in this file -- confirm how routes are registered.
app: FastAPI = FastAPI()

# Load the FHE server once, at import time, from the deployment artifacts in
# DEPLOYMENT_DIR so every request reuses the same instance.
FHE_SERVER: FHEModelServer = FHEModelServer(DEPLOYMENT_DIR)
def greet_json():
    """Return a minimal "Hello World" JSON payload.

    NOTE(review): no ``@app`` route decorator is visible on this function, so
    FastAPI will not serve it as written -- confirm the intended route.
    """
    greeting = dict(Hello="World!")
    return greeting
def root():
    """
    Root endpoint of the health prediction API.

    NOTE(review): no ``@app`` route decorator is visible on this function --
    confirm the intended route.

    Returns:
        dict: A single-key mapping ``{"message": <welcome text>}``.
    """
    welcome_text = "Welcome to your disease prediction with FHE!"
    return {"message": welcome_text}
def send_input(
    user_id: str = Form(),
    files: List[UploadFile] = File(),
):
    """Store a client's encrypted input and evaluation key on the server.

    Args:
        user_id: Identifier used to name the files written under ``SERVER_DIR``.
        files: Uploaded files, in order: ``files[0]`` is the encrypted input
            and ``files[1]`` is the serialized evaluation key.

    Returns:
        None on success (serialized by FastAPI as ``null``), or a 400
        ``JSONResponse`` when ``user_id`` is not a plain file-name component
        or fewer than two files were uploaded.
    """
    import shutil

    print("\nSend the data to the server ............\n")

    # user_id is request data interpolated into a file path: reject anything
    # containing a path separator (e.g. "../x") that could escape SERVER_DIR.
    if not user_id or Path(user_id).name != user_id:
        return JSONResponse(status_code=400, content={"error": "Invalid user ID."})

    # Both the encrypted input and the evaluation key are required.
    if len(files) < 2:
        return JSONResponse(
            status_code=400,
            content={"error": "Expected the encrypted input and the evaluation key."},
        )

    # Receive the Client's files (Evaluation key + Encrypted symptoms).
    # NOTE: the "_valuation_key" suffix (sic) must stay identical to the name
    # run_fhe reads back.
    evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
    encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"

    # Stream each upload to disk rather than buffering it fully in memory;
    # evaluation keys can be large.
    with encrypted_input_path.open("wb") as encrypted_input, evaluation_key_path.open(
        "wb"
    ) as evaluation_key:
        shutil.copyfileobj(files[0].file, encrypted_input)
        shutil.copyfileobj(files[1].file, evaluation_key)
def run_fhe(
    user_id: str = Form(),
):
    """Run the FHE model on a user's previously uploaded encrypted input.

    Reads the encrypted input and evaluation key stored under ``SERVER_DIR``,
    evaluates them with ``FHE_SERVER.run``, and writes the encrypted result to
    ``SERVER_DIR / f"{user_id}_encrypted_output"``.

    Args:
        user_id: Identifier used to locate the user's files under SERVER_DIR.

    Returns:
        JSONResponse: The FHE execution time in seconds, rounded to 2 decimals;
        a 400 response for an invalid ``user_id``; or a 404 response when the
        user's input files are not present.
    """
    print("\nRun in FHE in the server ............\n")

    # user_id is request data interpolated into a file path: reject path
    # separators so it cannot escape SERVER_DIR.
    if not user_id or Path(user_id).name != user_id:
        return JSONResponse(status_code=400, content={"error": "Invalid user ID."})

    # The "_valuation_key" suffix (sic) must match the name send_input writes.
    evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
    encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"

    # Read the files (Evaluation key + Encrypted symptoms).
    try:
        encrypted_input = encrypted_input_path.read_bytes()
        evaluation_key = evaluation_key_path.read_bytes()
    except FileNotFoundError:
        return JSONResponse(
            status_code=404,
            content={"error": "No encrypted input found; please send your data first."},
        )

    # Run the FHE execution; perf_counter is a monotonic clock meant for
    # measuring elapsed time (time.time can jump with clock adjustments).
    start = time.perf_counter()
    encrypted_output = FHE_SERVER.run(encrypted_input, evaluation_key)
    assert isinstance(encrypted_output, bytes)
    fhe_execution_time = round(time.perf_counter() - start, 2)

    # Persist the encrypted result so get_output can return it later.
    encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
    encrypted_output_path.write_bytes(encrypted_output)

    return JSONResponse(content=fhe_execution_time)
def get_output(user_id: str = Form()):
    """Retrieve the encrypted output from the server.

    Args:
        user_id: Identifier used to locate
            ``SERVER_DIR / f"{user_id}_encrypted_output"``.

    Returns:
        Response: The raw encrypted bytes (``application/octet-stream``); a 400
        JSON response for an invalid ``user_id``; or a 404 JSON response when
        no output has been computed for this user.
    """
    print("\nGet the output from the server ............\n")

    # user_id is request data interpolated into a file path: reject path
    # separators so it cannot escape SERVER_DIR.
    if not user_id or Path(user_id).name != user_id:
        return JSONResponse(status_code=400, content={"error": "Invalid user ID."})

    # Path where the encrypted output is saved.
    encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
    try:
        encrypted_output = encrypted_output_path.read_bytes()
    except FileNotFoundError:
        return JSONResponse(
            status_code=404,
            content={"error": "No encrypted output found; please run the FHE execution first."},
        )

    # Send the encrypted output as opaque binary data.
    return Response(encrypted_output, media_type="application/octet-stream")
def generate_keys(user_symptoms: List[str]):
    """
    Endpoint to generate FHE keys for a new user.

    Args:
        user_symptoms (List[str]): The list of user symptoms. It is only checked
            for emptiness here; the generated keys do not depend on it.

    Returns:
        JSONResponse: On success, ``user_id`` (a new random id),
        ``evaluation_key`` (the first INPUT_BROWSER_LIMIT hex characters of the
        serialized evaluation key) and ``evaluation_key_size`` (in MB). When no
        symptoms are given, a 400 response with an ``error`` message.
    """
    import secrets

    # Validate first so a rejected request does not trigger the cleanup below.
    if not user_symptoms:
        return JSONResponse(
            status_code=400, content={"error": "Please submit your symptoms first."}
        )

    clean_directory()

    # Random user ID in [0, 2**32). `secrets` keeps it unpredictable (it names
    # the user's key storage) and, unlike np.random.randint with the platform
    # default int dtype, cannot overflow where that dtype is 32-bit.
    user_id = secrets.randbelow(2**32)
    print(f"Your user ID is: {user_id}....")

    key_dir = KEYS_DIR / str(user_id)
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=key_dir)
    client.load()

    # Creates the private and evaluation keys on the client side.
    client.generate_private_and_evaluation_keys()

    # Get the serialized evaluation keys.
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Save the evaluation key; make sure the per-user directory exists first.
    evaluation_key_path = key_dir / "evaluation_key"
    evaluation_key_path.parent.mkdir(parents=True, exist_ok=True)
    evaluation_key_path.write_bytes(serialized_evaluation_keys)

    # Only a hex prefix (INPUT_BROWSER_LIMIT chars) is returned to the browser.
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
    return JSONResponse(
        content={
            "user_id": user_id,
            "evaluation_key": serialized_evaluation_keys_shorten_hex,
            "evaluation_key_size": f"{len(serialized_evaluation_keys) / (10**6):.2f} MB",
        }
    )
def run_dev_script():
    """
    Endpoint to execute the dev.py script to generate deployment files.

    dev.py (located next to this file) is run with the same interpreter as the
    server, and its captured stdout is returned on success.

    Returns:
        JSONResponse: ``{"message", "output"}`` on success, or a 500 response with
        ``{"error", "details"}`` if the script exits non-zero or cannot be started.
    """
    import sys

    # Define the path to dev.py
    dev_script_path = Path(__file__).parent / "dev.py"
    try:
        # sys.executable instead of a bare "python": PATH may resolve to another
        # interpreter (without this environment's packages) or to none at all.
        result = subprocess.run(
            [sys.executable, str(dev_script_path)],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # The script ran but exited with a non-zero status.
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to execute dev.py", "details": e.stderr},
        )
    except OSError as e:
        # The interpreter could not be launched at all.
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to execute dev.py", "details": str(e)},
        )

    # NOTE(review): FHE_SERVER is built once at import time; if dev.py rewrites
    # the files in DEPLOYMENT_DIR, the running server presumably keeps serving
    # the previously loaded model until restarted -- confirm.
    return JSONResponse(
        content={"message": "dev.py executed successfully!", "output": result.stdout}
    )