from unsloth import FastVisionModel  # Unsloth recommends importing it before transformers/peft so its patches apply

import gc

import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor
from peft import PeftModel
|

st.set_page_config(page_title="Deepfake Analyzer", layout="wide")

st.title("Deepfake Image Analyzer")
st.markdown("This app analyzes images for signs of deepfake manipulation.")


def free_memory():
    """Release Python objects and the CUDA cache between analysis runs."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


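# Some processor versions emit a cross-attention mask with a zero-sized
# dimension for Llama 3.2 Vision inputs; the helper below rebuilds it as
# an all-ones mask so generation can attend to every visual feature.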
|
def fix_processor_outputs(inputs):
    """Fix cross-attention mask dimensions if needed."""
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # visual feature count used for this model
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device,
            dtype=inputs['cross_attention_mask'].dtype  # keep the original mask dtype
        )
        inputs['cross_attention_mask'] = new_mask
        return True, inputs
    return False, inputs


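# st.cache_resource caches the return value for the whole server process,
# so the heavy download/quantization runs only on the first "Load Model" click.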
@st.cache_resource
def load_model():
    """Load the model using the Unsloth approach (similar to the Colab notebook)."""
    try:
        base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"

        processor = AutoProcessor.from_pretrained(base_model_id)

        model, _ = FastVisionModel.from_pretrained(
            base_model_id,
            load_in_4bit=True,
            dtype=torch.float16,  # Unsloth's keyword is `dtype`, not `torch_dtype`
            device_map="auto"
        )

        # Switch to inference mode before attaching the LoRA adapter.
        FastVisionModel.for_inference(model)

        adapter_id = "saakshigupta/deepfake-explainer-1"
        model = PeftModel.from_pretrained(model, adapter_id)

        return model, processor

    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.exception(e)
        return None, None


with st.sidebar:
    st.header("Settings")
    temperature = st.slider("Temperature", 0.1, 1.0, 0.7, 0.1)
    max_length = st.slider("Max length", 100, 500, 300, 50)

    prompt = st.text_area(
        "Analysis instruction",
        value="Analyze this image and determine if it's a deepfake. Provide your reasoning.",
        height=100
    )


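# Two-column layout: the steps live on the left, analysis output on the right.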
col1, col2 = st.columns([1, 2])

with col1:
    if st.button("1. Load Model"):
        with st.spinner("Loading model... (this may take a minute)"):
            model, processor = load_model()
            if model is not None and processor is not None:
                # Keep both in session state so they survive Streamlit reruns.
                st.session_state['model'] = model
                st.session_state['processor'] = processor
                st.success("✅ Model loaded successfully!")
            else:
                st.error("Failed to load model")

    uploaded_file = st.file_uploader("2. Upload an image", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image", use_container_width=True)

        model_loaded = 'model' in st.session_state and st.session_state['model'] is not None

        if st.button("3. Analyze Image", disabled=not model_loaded):
            if not model_loaded:
                st.warning("Please load the model first")
            else:
                col2.subheader("Analysis Results")
                # st.spinner is a module-level context manager; column objects don't expose one.
                with st.spinner("Analyzing image..."):
                    try:
                        model = st.session_state['model']
                        processor = st.session_state['processor']

                        # Chat-style message: one image slot plus the sidebar instruction.
                        messages = [
                            {"role": "user", "content": [
                                {"type": "image"},
                                {"type": "text", "text": prompt}
                            ]}
                        ]

                        input_text = processor.tokenizer.apply_chat_template(
                            messages,
                            add_generation_prompt=True,
                            tokenize=False  # return the prompt string; the processor tokenizes below
                        )

                        inputs = processor(
                            images=image,
                            text=input_text,
                            add_special_tokens=False,
                            return_tensors="pt"
                        ).to(model.device)

                        fixed, inputs = fix_processor_outputs(inputs)
                        if fixed:
                            col2.info("Fixed cross-attention mask dimensions")

                        with torch.no_grad():
                            output_ids = model.generate(
                                **inputs,
                                max_new_tokens=max_length,
                                do_sample=True,  # sampling must be enabled for temperature/top_p to apply
                                temperature=temperature,
                                top_p=0.9
                            )

                        # Decode only the newly generated tokens, not the echoed prompt.
                        prompt_len = inputs['input_ids'].shape[-1]
                        response = processor.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

                        col2.success("Analysis complete!")
                        col2.markdown(response)

                        # Free memory so repeated analyses don't accumulate GPU usage.
                        free_memory()

                    except Exception as e:
                        col2.error(f"Error analyzing image: {str(e)}")
                        col2.exception(e)
        elif not model_loaded:
            st.info("Please load the model first (Step 1)")
    else:
        st.info("Please upload an image (Step 2)")

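# Right-hand column placeholder shown until a model has been loaded.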
with col2:
    if 'model' not in st.session_state:
        st.info("👈 Follow the steps on the left to analyze an image")