import json
import logging
import mimetypes
import os
from urllib.parse import unquote

import pandas as pd
import requests
from flask import Flask, Response, request, send_from_directory
from flask_cors import CORS

from backend.config import (
    ABS_DATASET_DOMAIN,
    get_dataset_config,
    get_datasets,
)
from backend.descriptions import (
    DATASET_DESCRIPTIONS,
    DESCRIPTIONS,
    METRIC_DESCRIPTIONS,
    MODEL_DESCRIPTIONS,
)
from backend.examples import get_examples_tab
from tools import (
    get_leaderboard_filters,
    get_old_format_dataframe,
)

logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

logger.warning("Starting the Flask app...")

app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
CORS(app)

# NOTE: the @app.route / @app.errorhandler registrations below use paths
# inferred from each handler's purpose; they are assumptions, not confirmed
# by the source.
@app.route("/")
def index():
    logger.warning("Serving index.html")
    return send_from_directory(app.static_folder, "index.html")

@app.route("/api/datasets")  # path assumed
def datasets():
    """
    Returns the dataset configs grouped by audio / image / video.
    """
    return Response(json.dumps(get_datasets()), mimetype="application/json")
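
# Usage sketch (route path assumed): the frontend can load the grouped
# configs with a plain GET, e.g.
#   curl http://localhost:7860/api/datasets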

@app.route("/data/<dataset_name>")  # path assumed
def data_files(dataset_name):
    """
    Serves CSV files from S3 or locally based on config.
    """
    # Get dataset_type from query params
    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400
    dataset_config = get_dataset_config(dataset_name)
    file_path = (
        os.path.join(dataset_config["path"], dataset_name) + f"_{dataset_type}.csv"
    )
    logger.info(f"Looking for dataset file: {file_path}")
    try:
        df = pd.read_csv(file_path)
    except Exception:
        logger.error(f"Failed to fetch file: {file_path}")
        return "File not found", 404
    logger.info(f"Processing dataset: {dataset_name}")
    if dataset_type == "benchmark":
        return get_leaderboard(dataset_config, df)
    elif dataset_type == "attacks_variations":
        return get_chart(dataset_config, df)
    return f"Unknown dataset type: {dataset_type}", 400

@app.route("/files/<path:file_path>")  # path assumed
def serve_file_path(file_path):
    """
    Serves files from S3 or locally based on config.
    """
    # The path is used as given (a local path or a mounted dataset path)
    abs_path = file_path
    logger.info(f"Looking for file: {abs_path}")
    try:
        with open(abs_path, "rb") as f:
            content = f.read()
        return Response(content, mimetype="application/octet-stream")
    except FileNotFoundError:
        logger.error(f"Failed to fetch file: {abs_path}")
        return "File not found", 404

@app.route("/examples/<example_type>")  # path assumed
def example_files(example_type):  # renamed from `type` to avoid shadowing the builtin
    """
    Serve example files from S3 or locally based on config.
    """
    result = get_examples_tab(example_type)
    return Response(json.dumps(result), mimetype="application/json")

@app.route("/descriptions")  # path assumed
def descriptions():
    """
    Serve descriptions and model descriptions from descriptions.py.
    """
    return Response(
        json.dumps(
            {
                "descriptions": DESCRIPTIONS,
                "metric_descriptions": METRIC_DESCRIPTIONS,
                "model_descriptions": MODEL_DESCRIPTIONS,
                "dataset_descriptions": DATASET_DESCRIPTIONS,
            }
        ),
        mimetype="application/json",
    )

# Proxy endpoint to bypass CORS issues on remote resources
@app.route("/proxy/<path:url>")  # path assumed
def proxy(url):
    """
    Proxy endpoint to fetch remote files and serve them to the frontend.
    This helps bypass CORS restrictions on remote resources.
    """
    try:
        # Decode the URL parameter
        url = unquote(url)
        # Only proxy from the trusted dataset domain, for security
        if not url.startswith(ABS_DATASET_DOMAIN):
            return {"error": "Only proxying from allowed domains is permitted"}, 403
        if url.startswith(("http://", "https://")):
            response = requests.get(url, stream=True)
            if response.status_code != 200:
                return {"error": f"Failed to fetch from {url}"}, response.status_code
            # Copy the upstream headers, minus encoding / hop-by-hop headers
            excluded_headers = [
                "content-encoding",
                "content-length",
                "transfer-encoding",
                "connection",
            ]
            headers = {
                name: value
                for name, value in response.headers.items()
                if name.lower() not in excluded_headers
            }
            # Add CORS headers
            headers["Access-Control-Allow-Origin"] = "*"
            return Response(response.content, response.status_code, headers)
        else:
            # Serve a local file if the URL is not a network resource
            local_path = url
            if not os.path.exists(local_path):
                return {"error": f"Local file not found: {local_path}"}, 404
            with open(local_path, "rb") as f:
                content = f.read()
            # Guess content type based on file extension
            mime_type, _ = mimetypes.guess_type(local_path)
            headers = {"Access-Control-Allow-Origin": "*"}
            return Response(
                content,
                mimetype=mime_type or "application/octet-stream",
                headers=headers,
            )
    except Exception as e:
        return {"error": str(e)}, 500

def get_leaderboard(config, df):
    """Build the leaderboard JSON payload for a benchmark CSV."""
    logger.warning(f"Processing dataset with config: {config}")
    # Add the derived columns expected by the old leaderboard format
    df = get_old_format_dataframe(df, config["first_cols"], config["attack_scores"])
    groups, default_selection = get_leaderboard_filters(df, config["categories"])
    # Replace NaN values with the string "NaN" so they survive JSON serialization
    df = df.fillna(value="NaN")
    # Transpose the DataFrame so each metric becomes a row and each model a column
    df = df.set_index("model").T.reset_index()
    df = df.rename(columns={"index": "metric"})
    # Convert DataFrame to JSON
    result = {
        "groups": {group: list(metrics) for group, metrics in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": df.to_dict(orient="records"),
    }
    return Response(json.dumps(result), mimetype="application/json")
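
# Illustrative payload shape (metric and model names are examples only):
#   {
#     "groups": {"quality": ["psnr", "ssim"], ...},
#     "default_selected_metrics": ["psnr"],
#     "rows": [{"metric": "psnr", "model_a": 41.2, "model_b": 39.8}, ...],
#   }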

def get_chart(config, df):
    """Build the chart JSON payload for an attacks-variations CSV."""
    # Metrics shown in the attack-variation plots
    attacks_plot_metrics = [
        "bit_acc",
        "log10_p_value",
        "TPR",
        "FPR",
        "watermark_det_score",
    ]
    # Replace NaN values with the string "NaN" so they survive JSON serialization
    df = df.fillna(value="NaN")
    chart_data = {
        "metrics": attacks_plot_metrics,
        "attacks_with_variations": config["attacks_with_variations"],
        "all_attacks_df": df.to_dict(orient="records"),
    }
    return Response(json.dumps(chart_data), mimetype="application/json")
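
# In the payload above, "metrics" lists the plotted metric names,
# "attacks_with_variations" comes straight from the dataset config, and
# "all_attacks_df" is the CSV serialized as one dict per row.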

@app.errorhandler(404)  # registration assumed; implied by the SPA-fallback comment
def not_found(e):
    # Serve index.html for any 404 (SPA fallback)
    return send_from_directory(app.static_folder, "index.html")
| if __name__ == "__main__": | |
| app.run(host="0.0.0.0", port=7860, debug=True, use_reloader=True) | |
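
# Dev entry point: `python app.py` serves the API and the built frontend on
# http://0.0.0.0:7860 with debug mode and auto-reload enabled.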