saakshigupta's picture
Update app.py
6e58aae verified
raw
history blame
4.42 kB
import streamlit as st
import torch
from PIL import Image
import io
from peft import PeftModel
from unsloth import FastVisionModel
import tempfile
import os
# Page header. Streamlit has historically required set_page_config to be the
# first st.* command of the script, so it runs here, before anything renders.
st.set_page_config(layout="wide", page_title="Deepfake Analyzer")
st.title(body="Deepfake Image Analyzer")
st.markdown(body="Upload an image to analyze it for potential deepfake manipulation")
# Function to fix cross-attention masks
def fix_cross_attention_mask(inputs, *, visual_features=6404, notify=None):
    """Replace a degenerate ``cross_attention_mask`` in *inputs* with all ones.

    When *inputs* contains a ``cross_attention_mask`` whose shape has a
    zero-length dimension, that entry is replaced in place by a tensor of
    ones of shape ``(batch_size, seq_len, visual_features, num_tiles)``.
    ``batch_size``, ``seq_len`` and ``num_tiles`` are taken from the old mask,
    whose device and dtype are kept as well. Any other input is returned
    untouched.

    Args:
        inputs: Dict-like model inputs (the processor output); mutated in place.
        visual_features: Size of the third axis of the replacement mask.
        notify: Callable receiving a status message after a replacement;
            defaults to ``st.success`` so the fix is reported in the UI.

    Returns:
        The same *inputs* object.

    Raises:
        ValueError: If the mask that needs replacing is not 4-dimensional.
    """
    if "cross_attention_mask" not in inputs:
        return inputs
    mask = inputs["cross_attention_mask"]
    if 0 not in mask.shape:
        return inputs

    batch_size, seq_len, _, num_tiles = mask.shape
    # NOTE(review): 6404 is the original hard-coded "critical dimension".
    # Hugging Face documents the Mllama processor's cross_attention_mask as
    # (batch, seq_len, max_num_images, max_num_tiles), i.e. the third axis
    # counts images rather than visual tokens -- confirm this size is what
    # the model actually expects.
    inputs["cross_attention_mask"] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        dtype=mask.dtype,  # keep the processor's dtype instead of defaulting to float32
        device=mask.device,
    )
    report = st.success if notify is None else notify
    report("Fixed cross-attention mask dimensions")
    return inputs
# Load model function
@st.cache_resource(show_spinner=False)
def _load_model_cached(base_model_id, adapter_id):
    """Load the 4-bit base vision model plus the PEFT adapter; raise on failure.

    Cached by Streamlit per (base_model_id, adapter_id). Exceptions propagate
    so that a failed load is NOT stored in the cache and is retried on the
    next rerun.
    """
    model, tokenizer = FastVisionModel.from_pretrained(
        base_model_id,
        load_in_4bit=True,
    )
    model = PeftModel.from_pretrained(model, adapter_id)
    FastVisionModel.for_inference(model)
    return model, tokenizer


def load_model(
    base_model_id="unsloth/llama-3.2-11b-vision-instruct",
    adapter_id="saakshigupta/deepfake-explainer-1",
):
    """Return ``(model, tokenizer)`` ready for inference, or ``(None, None)``.

    The heavy loading is delegated to a cached helper. On failure, the error
    is shown in the UI and its traceback is logged, and ``(None, None)`` is
    returned. The failure is not cached, so a later rerun retries the load.

    Args:
        base_model_id: Hub id of the base vision model (loaded in 4-bit).
        adapter_id: Hub id of the PEFT adapter applied on top of it.
    """
    import logging

    with st.spinner("Loading model... This may take a minute or two..."):
        try:
            return _load_model_cached(base_model_id, adapter_id)
        except Exception as e:
            # UI boundary: report instead of crashing the page, and log the
            # full traceback so "check the console" actually finds something.
            logging.getLogger(__name__).exception("Error loading model")
            st.error(f"Error loading model: {str(e)}")
            return None, None
# Analyze image function
def analyze_image(
    image,
    question,
    model,
    tokenizer,
    *,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
):
    """Ask the vision model *question* about *image* and return its answer.

    Args:
        image: PIL image passed to the processor together with the prompt.
        question: User prompt, placed after the image in a single user turn.
        model: Model returned by ``load_model`` (needs ``generate``/``device``).
        tokenizer: Processor/tokenizer returned by ``load_model``; used for the
            chat template, encoding and decoding.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        Only the newly generated text, stripped of surrounding whitespace.
    """
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": question},
        ]}
    ]
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)
    inputs = fix_cross_attention_mask(inputs)
    prompt_length = inputs["input_ids"].shape[-1]

    with st.spinner("Analyzing image... (this may take a moment)"):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                # temperature/top_p are ignored by generate() unless sampling is on.
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
            )

    # The returned sequence starts with the prompt tokens. Slice them off
    # rather than splitting the decoded text on the question, which raises
    # ValueError for an empty question and misfires if the answer repeats it.
    generated_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Main app
def main():
    """Render the analyzer page: load the model, take an upload and a prompt, show the answer."""
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        st.warning("Failed to load the model. Please check the console for errors.")
        return

    st.success("✅ Model loaded successfully! You can now analyze images.")

    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Default question, editable by the user.
    default_question = "Analyze this image and tell me if it's a deepfake. Provide both technical and non-technical explanations."
    question = st.text_area("Question/Prompt:", value=default_question, height=100)

    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        # PIL.UnidentifiedImageError (and truncated-file errors) subclass OSError;
        # show a message instead of crashing the page on a bad upload.
        st.error(f"Could not read the uploaded file as an image: {e}")
        return
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Analyze Image"):
        if not question.strip():
            st.warning("Please enter a question or prompt before analyzing.")
            return
        result = analyze_image(image, question, model, tokenizer)
        st.subheader("Analysis Results")
        st.markdown(result)


if __name__ == "__main__":
    main()