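# Streamlit app: upload an image and have a fine-tuned Llama 3.2 Vision model
# explain whether it shows signs of deepfake manipulation.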
import streamlit as st
import torch
from PIL import Image
import io
from transformers import AutoProcessor, BitsAndBytesConfig, MllamaForConditionalGeneration
from peft import PeftModel

st.set_page_config(
    page_title="Deepfake Explainer",
    page_icon="🔍",
    layout="wide"
)

st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for possible deepfake manipulation")

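# Load the 4-bit quantized base model and attach the fine-tuned deepfake-explainer LoRA adapter.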
@st.cache_resource
def load_model():
    """Load model and processor (cached to avoid reloading)"""
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
    processor = AutoProcessor.from_pretrained(base_model_id)

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )

    # MllamaForConditionalGeneration (rather than MllamaForCausalLM) is the image+text
    # variant that accepts the processor's pixel_values alongside the prompt.
    model = MllamaForConditionalGeneration.from_pretrained(
        base_model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=quantization_config
    )

    # Apply the fine-tuned LoRA adapter on top of the base model
    adapter_id = "saakshigupta/deepfake-explainer-1"
    model = PeftModel.from_pretrained(model, adapter_id)

    return model, processor


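# The processor can occasionally return an empty cross-attention mask; rebuild it as an
# all-ones mask of the expected shape so generation can still run.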
def fix_processor_outputs(inputs):
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
        inputs['cross_attention_mask'] = new_mask
        st.write("✅ Fixed cross-attention mask dimensions")
    return inputs


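# Load the model once at startup; the first run may need to download the weights.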
with st.spinner("Loading model... this may take a minute."):
    model, processor = load_model()
    st.success("Model loaded successfully!")

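# Sidebar: generation settings and an editable analysis prompt.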
with st.sidebar:
    st.header("Options")
    temperature = st.slider(
        "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
        help="Higher values make output more random, lower values more deterministic"
    )
    max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)

    custom_prompt = st.text_area(
        "Custom instruction (optional)",
        value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
        height=100
    )

    st.markdown("### About")
    st.markdown("This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.")
    st.markdown("Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)")

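# Main flow: upload an image, then run the analysis on demand.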
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Analyze Image"):
        with st.spinner("Analyzing the image..."):
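            # The Mllama processor expects an <|image|> placeholder in the text when an
            # image is supplied, so prepend it to the user's instruction.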
            prompt = f"<|image|>{custom_prompt}"
            inputs = processor(text=prompt, images=image, return_tensors="pt")

            inputs = fix_processor_outputs(inputs)

            # Move all tensors to the model's device
            inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

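            # Generate the analysis; temperature and length come from the sidebar controls.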
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,  # enable sampling so temperature/top_p take effect
                    temperature=temperature,
                    top_p=0.9
                )

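            # Decode the output and strip the echoed prompt so only the model's answer is shown.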
            response = processor.decode(output_ids[0], skip_special_tokens=True)

            if custom_prompt in response:
                result = response.split(custom_prompt)[-1].strip()
            else:
                result = response

            st.success("Analysis complete!")

if "Technical Explanation:" in result and "Non-Technical Explanation:" in result: |
|
technical, non_technical = result.split("Non-Technical Explanation:") |
|
technical = technical.replace("Technical Explanation:", "").strip() |
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
st.subheader("Technical Analysis") |
|
st.write(technical) |
|
|
|
with col2: |
|
st.subheader("Simple Explanation") |
|
st.write(non_technical) |
|
else: |
|
st.subheader("Analysis Result") |
|
st.write(result) |
|
else:
    st.info("Please upload an image to begin analysis")