saakshigupta commited on
Commit
473311a
·
verified ·
1 Parent(s): e3ee4e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ import torch
4
+ from PIL import Image
5
+ import io
6
+ from transformers import AutoProcessor, AutoModelForCausalLM
7
+ from peft import PeftModel
8
+
9
+ # Page config
10
+ st.set_page_config(
11
+ page_title="Deepfake Explainer",
12
+ page_icon="🔍",
13
+ layout="wide"
14
+ )
15
+
16
+ # App title and description
17
+ st.title("Deepfake Image Analyzer")
18
+ st.markdown("Upload an image to analyze it for possible deepfake manipulation")
19
+
20
+ @st.cache_resource
21
+ def load_model():
22
+ """Load model and processor (cached to avoid reloading)"""
23
+ # Load base model
24
+ base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
25
+ processor = AutoProcessor.from_pretrained(base_model_id)
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ base_model_id,
28
+ device_map="auto",
29
+ torch_dtype=torch.float16
30
+ )
31
+
32
+ # Load adapter
33
+ adapter_id = "saakshigupta/deepfake-explainer-1"
34
+ model = PeftModel.from_pretrained(model, adapter_id)
35
+
36
+ return model, processor
37
+
38
+ # Function to fix cross-attention masks
39
+ def fix_processor_outputs(inputs):
40
+ if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
41
+ batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
42
+ visual_features = 6404 # The exact dimension we fixed in training
43
+ new_mask = torch.ones((batch_size, seq_len, visual_features, num_tiles),
44
+ device=inputs['cross_attention_mask'].device)
45
+ inputs['cross_attention_mask'] = new_mask
46
+ st.write("✅ Fixed cross-attention mask dimensions")
47
+ return inputs
48
+
49
+ # Load model on first run
50
+ with st.spinner("Loading model... this may take a minute."):
51
+ model, processor = load_model()
52
+ st.success("Model loaded successfully!")
53
+
54
+ # Create sidebar with options
55
+ with st.sidebar:
56
+ st.header("Options")
57
+ temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
58
+ help="Higher values make output more random, lower values more deterministic")
59
+ max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)
60
+
61
+ custom_prompt = st.text_area(
62
+ "Custom instruction (optional)",
63
+ value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
64
+ height=100
65
+ )
66
+
67
+ st.markdown("### About")
68
+ st.markdown("This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.")
69
+ st.markdown("Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)")
70
+
71
+ # Main content area - file uploader
72
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
73
+
74
+ if uploaded_file is not None:
75
+ # Display the image
76
+ image = Image.open(uploaded_file).convert('RGB')
77
+ st.image(image, caption="Uploaded Image", use_column_width=True)
78
+
79
+ # Analyze button
80
+ if st.button("Analyze Image"):
81
+ with st.spinner("Analyzing the image..."):
82
+ # Process the image
83
+ inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
84
+
85
+ # Fix cross-attention mask
86
+ inputs = fix_processor_outputs(inputs)
87
+
88
+ # Move to device
89
+ inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
90
+
91
+ # Generate the analysis
92
+ with torch.no_grad():
93
+ output_ids = model.generate(
94
+ **inputs,
95
+ max_new_tokens=max_length,
96
+ temperature=temperature,
97
+ top_p=0.9
98
+ )
99
+
100
+ # Decode the output
101
+ response = processor.decode(output_ids[0], skip_special_tokens=True)
102
+
103
+ # Extract the actual response (removing the prompt)
104
+ if custom_prompt in response:
105
+ result = response.split(custom_prompt)[-1].strip()
106
+ else:
107
+ result = response
108
+
109
+ # Display result in a nice format
110
+ st.success("Analysis complete!")
111
+
112
+ # Show technical and non-technical explanations separately if they exist
113
+ if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
114
+ technical, non_technical = result.split("Non-Technical Explanation:")
115
+ technical = technical.replace("Technical Explanation:", "").strip()
116
+
117
+ col1, col2 = st.columns(2)
118
+ with col1:
119
+ st.subheader("Technical Analysis")
120
+ st.write(technical)
121
+
122
+ with col2:
123
+ st.subheader("Simple Explanation")
124
+ st.write(non_technical)
125
+ else:
126
+ st.subheader("Analysis Result")
127
+ st.write(result)
128
+ else:
129
+ st.info("Please upload an image to begin analysis")