|
import io
import logging
import os
import tempfile

# Unsloth asks to be imported before peft/transformers so its patches are applied.
from unsloth import FastVisionModel

import streamlit as st
import torch
from peft import PeftModel
from PIL import Image
|
|
|
|
|
# Page configuration: browser-tab title and full-width layout.
st.set_page_config(page_title="Deepfake Analyzer", layout="wide")

# Header and short usage hint rendered at the top of every script run.
st.title("Deepfake Image Analyzer")

st.markdown("Upload an image to analyze it for potential deepfake manipulation")
|
|
|
|
|
def fix_cross_attention_mask(inputs, visual_features=6404, notify=None):
    """Replace a zero-sized ``cross_attention_mask`` with an all-ones mask.

    Workaround for processor outputs whose ``cross_attention_mask`` has a
    zero-length dimension. The replacement has shape
    ``(batch_size, seq_len, visual_features, num_tiles)``, where batch_size,
    seq_len and num_tiles are taken from the original mask's 1st, 2nd and
    4th dimensions. Inputs without the key, or with a non-empty mask, are
    returned untouched.

    Args:
        inputs: Dict-like model inputs (e.g. the processor output). Modified
            in place when the mask is replaced.
        visual_features: Size of the replacement mask's third dimension.
            NOTE(review): the default 6404 is the value this app has always
            used; its derivation is not documented here -- confirm it against
            the model's vision config.
        notify: Optional callable taking a message string, invoked after the
            mask is replaced. Defaults to ``st.success`` so the fix is still
            reported in the UI, but callers (and tests) can pass their own.

    Returns:
        The same ``inputs`` object.
    """
    if "cross_attention_mask" not in inputs:
        return inputs

    mask = inputs["cross_attention_mask"]
    if 0 not in mask.shape:
        return inputs

    batch_size, seq_len, _, num_tiles = mask.shape
    # Keep the dtype/device of the mask the processor produced so the
    # replacement is interchangeable with a normally-built mask.
    inputs["cross_attention_mask"] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        dtype=mask.dtype,
        device=mask.device,
    )

    if notify is None:
        notify = st.success
    notify("Fixed cross-attention mask dimensions")
    return inputs
|
|
|
|
|
@st.cache_resource
def _load_model_cached(base_model_id, adapter_id):
    """Load the 4-bit base model, attach the PEFT adapter, enable inference mode.

    Deliberately lets exceptions propagate: Streamlit only caches successful
    return values, so a failed load is retried on the next rerun instead of
    being cached as a permanent failure.
    """
    with st.spinner("Loading model... This may take a minute or two..."):
        model, tokenizer = FastVisionModel.from_pretrained(
            base_model_id,
            load_in_4bit=True,
        )
        model = PeftModel.from_pretrained(model, adapter_id)
        FastVisionModel.for_inference(model)
    return model, tokenizer


def load_model(
    base_model_id="unsloth/llama-3.2-11b-vision-instruct",
    adapter_id="saakshigupta/deepfake-explainer-1",
):
    """Return the cached ``(model, tokenizer)`` pair, or ``(None, None)`` on failure.

    Args:
        base_model_id: Hub id of the base vision model, loaded in 4-bit.
        adapter_id: Hub id of the PEFT adapter applied on top of the base model.

    Returns:
        ``(model, tokenizer)`` on success. On failure the error is shown in
        the UI, the traceback is logged to the console, and ``(None, None)``
        is returned.
    """
    try:
        return _load_model_cached(base_model_id, adapter_id)
    except Exception as e:  # UI boundary: report instead of crashing the app
        logging.getLogger(__name__).exception("Error loading model")
        st.error(f"Error loading model: {str(e)}")
        return None, None
|
|
|
|
|
def analyze_image(image, question, model, tokenizer):
    """Ask the vision model ``question`` about ``image`` and return its answer.

    Args:
        image: PIL image to analyze.
        question: Prompt text sent together with the image.
        model: Model returned by ``load_model``.
        tokenizer: Tokenizer/processor returned by ``load_model``; it is called
            with ``(image, text)`` to build the model inputs.

    Returns:
        The generated answer text only (prompt tokens are not included).
    """
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": question},
        ]}
    ]
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # add_special_tokens=False: the rendered chat template presumably already
    # contains the special tokens, so they are not added a second time.
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)
    inputs = fix_cross_attention_mask(inputs)

    with st.spinner("Analyzing image... (this may take a moment)"):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                use_cache=True,
                # temperature/top_p are only honoured in sampling mode.
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

    # generate() returns prompt + continuation. Decode only the new tokens
    # instead of string-splitting on the question, which left chat-template
    # header text in the answer, truncated answers that repeat the question,
    # and raised ValueError for an empty question.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = output_ids[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
def main():
    """Render the app: load the model, take an image and prompt, show the analysis."""
    model, tokenizer = load_model()

    if model is None or tokenizer is None:
        st.warning("Failed to load the model. Please check the console for errors.")
        return

    # The emoji was previously mis-encoded and split this literal across two
    # lines (a syntax error); restored to the intended "✅".
    st.success("✅ Model loaded successfully! You can now analyze images.")

    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    default_question = "Analyze this image and tell me if it's a deepfake. Provide both technical and non-technical explanations."
    question = st.text_area("Question/Prompt:", value=default_question, height=100)

    if uploaded_file is None:
        return

    # Uploads are untrusted: PIL raises OSError (UnidentifiedImageError is a
    # subclass) for files it cannot identify or decode.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        return

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; switch once the deployed version is known.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Analyze Image"):
        return

    if not question.strip():
        st.warning("Please enter a question or prompt.")
        return

    try:
        result = analyze_image(image, question, model, tokenizer)
    except Exception as e:  # UI boundary: show the failure instead of a traceback
        logging.getLogger(__name__).exception("Image analysis failed")
        st.error(f"Error analyzing image: {e}")
        return

    st.subheader("Analysis Results")
    st.markdown(result)
|
|
|
# Script entry point: build the UI when this module is executed as the main script.
if __name__ == "__main__":
    main()