File size: 3,714 Bytes
ff13394
 
 
 
 
 
 
7e1e741
9a5bfef
 
 
ff13394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e1e741
ff13394
 
9b256ac
7e1e741
ff13394
7e1e741
9b256ac
7e1e741
ff13394
 
c176b63
ff13394
 
 
 
 
 
7e1e741
 
c176b63
7e1e741
 
 
 
ff13394
7e1e741
 
 
ff13394
 
 
 
 
 
 
9b256ac
 
 
 
7e1e741
9b256ac
 
 
 
ff13394
 
 
 
 
 
 
 
 
7e1e741
 
 
ff13394
 
c176b63
ff13394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import io
import os
from PIL import Image
from diffusers import StableDiffusionPipeline
import os

# Hugging Face access token taken from the HF_TOKEN environment variable; None if unset.
token = os.environ.get("HF_TOKEN")

# Define the MIDM model
class MIDM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MIDM, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

# Flask app serving pages from templates/ and assets from static/.
# CORS(app) with no arguments lets flask-cors accept cross-origin requests from
# any origin on every route — confirm that is intended for this deployment.
app = Flask(
    __name__,
    template_folder='templates',
    static_folder='static',
)
CORS(app)

# Populated lazily by load_models() from the /api/check-membership handler (not
# at startup), then reused across requests until a different model is requested.
stable_diff_pipe = None  # StableDiffusionPipeline once loaded
model = None  # MIDM classifier once loaded

def load_models(model_name="CompVis/stable-diffusion-v1-4"):
    """Load the Stable Diffusion pipeline and build the MIDM classifier.

    Sets the module-level globals ``stable_diff_pipe`` and ``model``. Both are
    placed on CUDA when it is available, otherwise on CPU.

    Args:
        model_name: model id (or local path) passed to
            ``StableDiffusionPipeline.from_pretrained``.

    Raises:
        Whatever ``from_pretrained`` raises for an unknown or inaccessible
        model; in that case the previously loaded globals are left unchanged.
    """
    global stable_diff_pipe, model

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Forward the HF_TOKEN access token so gated/private repositories can be
    # fetched; token=None keeps the library's default credential lookup.
    pipe = StableDiffusionPipeline.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))
    pipe.to(device)

    # MIDM head; input_dim must match the number of latent values fed to it.
    # It is placed on the same device as the pipeline so features produced on
    # CUDA can be fed to it directly.
    # NOTE(review): weights are randomly initialised on every call and never
    # loaded from a checkpoint, so outputs are untrained — load trained weights
    # here before relying on the predictions.
    input_dim = 10
    hidden_dim = 64
    output_dim = 1
    midm = MIDM(input_dim, hidden_dim, output_dim).to(device)
    midm.eval()

    # Publish both only once everything loaded, so a failed load never leaves a
    # new pipeline paired with the old classifier (or vice versa).
    stable_diff_pipe, model = pipe, midm

# Feature extraction: encode an image with the loaded pipeline's VAE.
def extract_image_features(image):
    """Return the VAE latent-distribution mean for ``image``.

    The image goes through the pipeline's ``feature_extractor``
    (``return_tensors="pt"``). Its ``pixel_values`` are moved to the pipeline's
    device and encoded with ``stable_diff_pipe.vae``. Autograd is disabled
    during encoding.

    Args:
        image: input image. check_membership passes a ``PIL.Image.Image`` from
            ``Image.open``; presumably anything the feature extractor accepts
            works.

    Returns:
        torch.Tensor: ``vae.encode(...).latent_dist.mean``. The shape is assumed
        to be ``(1, latent_channels, h, w)`` — TODO confirm.

    NOTE(review): ``feature_extractor`` is assumed to be the CLIP image
    processor that SD v1 pipelines ship for the safety checker. Its resize and
    normalisation likely differ from the [-1, 1] inputs the VAE expects, and
    pipelines loaded without one may have it set to None — confirm intent.
    """
    extractor = stable_diff_pipe.feature_extractor
    encoded = extractor(image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(stable_diff_pipe.device)

    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        posterior = stable_diff_pipe.vae.encode(pixel_values).latent_dist
        features = posterior.mean

    return features

@app.route('/')
def index():
    """Render the front-end page (index.html from the app's template folder)."""
    template_name = 'index.html'
    return render_template(template_name)

@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Score whether an uploaded image was likely in the model's training data.

    Request (multipart/form-data):
        image: required image file.
        model: optional model id for the Stable Diffusion pipeline; defaults
            to 'CompVis/stable-diffusion-v1-4'.

    Responses:
        200: JSON with 'probability' (MIDM output), 'predicted_class'
            (1 if probability > 0.5 else 0), 'message' and
            'is_in_training_data' ('Likely' / 'Unlikely').
        400: JSON {'error': ...} when no image is sent or the upload cannot be
            identified as an image.
        500: JSON {'error': ...} for any other failure, including model loading.
    """
    # Local import keeps this handler self-contained; Pillow is already required.
    from PIL import UnidentifiedImageError

    # Reject malformed requests before doing any (very expensive) model loading.
    if 'image' not in request.files:
        return jsonify({'error': 'No image found in request'}), 400

    # NOTE(review): model_name is client-controlled and is handed to
    # from_pretrained(), so any caller can make the server download and load an
    # arbitrary Hub repository. Consider restricting it to an allow-list.
    model_name = request.form.get('model', 'CompVis/stable-diffusion-v1-4')

    try:
        # Decode first so unreadable uploads are rejected before any model load.
        image_bytes = request.files['image'].read()
        image = Image.open(io.BytesIO(image_bytes))

        # (Re)load when nothing is loaded yet or a different pipeline is
        # requested. load_models() also re-creates the MIDM classifier.
        if (
            stable_diff_pipe is None
            or model is None
            or stable_diff_pipe.name_or_path != model_name
        ):
            load_models(model_name)

        # Encode the image into Stable Diffusion VAE latents.
        image_features = extract_image_features(image)

        # MIDM consumes the first 10 flattened latent values. The pipeline may
        # run on CUDA while MIDM need not be on the same device, so move the
        # features to wherever the classifier's parameters live.
        midm_device = next(model.parameters()).device
        processed_features = image_features.reshape(1, -1)[:, :10].to(midm_device)

        with torch.no_grad():
            probability = model(processed_features).item()
        predicted = int(probability > 0.5)

        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })

    except UnidentifiedImageError as e:
        # Client error: the upload is not an image format Pillow recognises.
        return jsonify({'error': f'Invalid image: {e}'}), 400
    except Exception as e:
        # Top-level request boundary: log the traceback, report to the client.
        app.logger.exception('Membership check failed')
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Listen on all interfaces; the PORT environment variable overrides 7860.
    listen_port = int(os.getenv('PORT', '7860'))
    app.run(host='0.0.0.0', port=listen_port)