File size: 4,415 Bytes
473311a
 
 
6e58aae
a96c23c
eab7cdf
6e58aae
 
473311a
6e58aae
eab7cdf
473311a
6e58aae
99ef832
eab7cdf
6e58aae
eab7cdf
 
6e58aae
 
 
eab7cdf
6e58aae
 
99ef832
eab7cdf
473311a
 
6e58aae
 
 
 
 
 
 
 
473311a
6e58aae
 
 
99ef832
6e58aae
 
 
 
 
 
 
eab7cdf
6e58aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473311a
6e58aae
 
 
 
 
 
eab7cdf
6e58aae
 
 
eab7cdf
6e58aae
 
 
 
 
 
 
 
 
 
 
 
eab7cdf
6e58aae
99ef832
6e58aae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import torch
from PIL import Image
import io
from peft import PeftModel
from unsloth import FastVisionModel
import tempfile
import os

# App title and description.
# NOTE: Streamlit has (at least in older releases) required st.set_page_config
# to be the first st.* command the script runs, so keep these page-setup calls
# here at the top, ahead of any other Streamlit call in the file.
st.set_page_config(page_title="Deepfake Analyzer", layout="wide")
st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for potential deepfake manipulation")

# Function to fix cross-attention masks
def fix_cross_attention_mask(inputs, visual_features=6404, notify=None):
    """Replace a zero-sized ``cross_attention_mask`` with an all-ones mask.

    If ``inputs`` holds a ``cross_attention_mask`` with any zero-length
    dimension, it is replaced by
    ``torch.ones((batch, seq_len, visual_features, num_tiles))`` created on the
    same device and with the same dtype as the original mask. Otherwise
    ``inputs`` is returned untouched.

    Args:
        inputs: Dict-like batch of model inputs (in this app, the processor
            output built in ``analyze_image``). Modified in place.
        visual_features: Size used for the rebuilt mask's third dimension.
            Defaults to 6404, the value this app has always used.
        notify: Callable receiving a status message after a fix is applied.
            Defaults to ``st.success`` so the user sees it in the UI.

    Returns:
        The same ``inputs`` object.

    Raises:
        ValueError: if the mask is zero-sized but not 4-dimensional (the
            shape is unpacked into exactly four values).
    """
    if "cross_attention_mask" not in inputs:
        return inputs
    mask = inputs["cross_attention_mask"]
    if 0 not in mask.shape:
        return inputs  # mask already has usable dimensions

    # NOTE(review): Mllama-style processors are believed to emit this mask as
    # (batch, seq_len, max_num_images, max_num_tiles); the third axis is
    # overwritten here with `visual_features` — confirm this is intended.
    batch_size, seq_len, _, num_tiles = mask.shape
    inputs["cross_attention_mask"] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        dtype=mask.dtype,  # keep the processor's dtype instead of defaulting to float32
        device=mask.device,
    )

    if notify is None:
        notify = st.success
    notify("Fixed cross-attention mask dimensions")
    return inputs

# Load model function
@st.cache_resource
def _load_model_cached():
    """Load the 4-bit base vision model plus the LoRA adapter, once per process.

    Cached with ``st.cache_resource`` so reruns reuse the same objects. Errors
    are deliberately NOT caught here: Streamlit only caches a call that
    returns, so a failed load (e.g. a network error while downloading weights)
    is retried on the next rerun instead of being memoized as a failure.

    Returns:
        ``(model, tokenizer)`` as produced by Unsloth and PEFT.
    """
    with st.spinner("Loading model... This may take a minute or two..."):
        # Load base model and tokenizer using Unsloth
        base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
        model, tokenizer = FastVisionModel.from_pretrained(
            base_model_id,
            load_in_4bit=True,
        )

        # Load the adapter on top of the base model
        adapter_id = "saakshigupta/deepfake-explainer-1"
        model = PeftModel.from_pretrained(model, adapter_id)

        # Set to inference mode
        FastVisionModel.for_inference(model)

        return model, tokenizer

def load_model():
    """Return the cached ``(model, tokenizer)``, or ``(None, None)`` on failure.

    A failure is shown in the UI with ``st.error`` and logged (with traceback)
    to the server console. ``(None, None)`` is then returned so the caller can
    render a fallback message. Because the failure happens outside the cached
    function, the next rerun attempts the load again.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # top-level UI boundary: report, don't crash the page
        import logging

        logging.getLogger(__name__).exception("Error loading model")
        st.error(f"Error loading model: {str(e)}")
        return None, None

# Analyze image function
def analyze_image(image, question, model, tokenizer):
    """Ask the fine-tuned vision model ``question`` about ``image``.

    Args:
        image: RGB ``PIL.Image`` to analyze.
        question: User prompt sent together with the image.
        model: Model returned by ``load_model``.
        tokenizer: Processor/tokenizer returned by ``load_model``; it is called
            with both the image and the chat-formatted prompt.

    Returns:
        The model's generated answer as a stripped string, without the prompt.
    """
    # Format the message: one user turn with an image placeholder and the text
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": question}
        ]}
    ]

    # Apply chat template
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Process with image
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,  # presumably the chat template already adds them
        return_tensors="pt",
    ).to(model.device)

    # Fix cross-attention mask if needed
    inputs = fix_cross_attention_mask(inputs)

    # Remember the prompt length so only newly generated tokens are decoded.
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate response
    with st.spinner("Analyzing image... (this may take a moment)"):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                use_cache=True,
                do_sample=True,  # temperature/top_p only take effect when sampling
                temperature=0.7,
                top_p=0.9,
            )

        # Decode only the tokens generated after the prompt. Splitting the full
        # decoded text on `question` fails for an empty prompt (str.split("")
        # raises ValueError), can leave chat-template header text in the answer,
        # and cuts in the wrong place if the answer repeats the question.
        generated_ids = output_ids[0][prompt_length:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

# Main app
def main():
    """Render the UI: load the model, accept an image and a prompt, show the analysis."""
    # Load model (cached across reruns by load_model)
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        st.warning("Failed to load the model. Please check the console for errors.")
        return

    st.success("✅ Model loaded successfully! You can now analyze images.")

    # Image upload section
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Default question with option to customize
    default_question = "Analyze this image and tell me if it's a deepfake. Provide both technical and non-technical explanations."
    question = st.text_area("Question/Prompt:", value=default_question, height=100)

    if uploaded_file is None:
        return

    # The uploader only checks the extension; a corrupt or mislabelled file
    # makes Pillow raise OSError (UnidentifiedImageError is a subclass).
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        return

    # Display the uploaded image
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Analyze button
    if st.button("Analyze Image"):
        if not question.strip():
            st.warning("Please enter a question/prompt before analyzing.")
            return

        result = analyze_image(image, question, model, tokenizer)

        # Display results
        st.subheader("Analysis Results")
        st.markdown(result)

# Entry point: render the app when this file is executed as the main script
# (e.g. via `streamlit run`, which re-executes the script on every interaction).
if __name__ == "__main__":
    main()