File size: 1,560 Bytes
5457740
fa35904
5457740
fa35904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21fe83f
fa35904
 
 
5457740
fa35904
 
 
 
 
 
5457740
fa35904
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# --- Page configuration and styled header -------------------------------
# Wide HTML is injected via st.markdown, so unsafe_allow_html must be on.
PAGE_TITLE = "Transflower ๐ŸŒธ"
HEADER_HTML = "<h1 style='text-align: center; color: #D16BA5;'>Transflower ๐ŸŒธ</h1>"
TAGLINE_HTML = "<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>"

st.set_page_config(page_title=PAGE_TITLE, layout="centered")
st.markdown(HEADER_HTML, unsafe_allow_html=True)
st.markdown(TAGLINE_HTML, unsafe_allow_html=True)

# Load model and tokenizer.
# NOTE(fix): Streamlit reruns the whole script on every widget interaction,
# so an unguarded from_pretrained() here would rebuild the tokenizer and the
# (large) seq2seq model on each rerun. st.cache_resource makes the load
# happen once per server process and reuse the same objects afterwards.
@st.cache_resource(show_spinner="Loading model…")
def _load_model_and_tokenizer(name):
    """Return (tokenizer, model) for *name*, cached across Streamlit reruns.

    output_attentions=True is required so the later forward pass exposes
    encoder_attentions for the heatmap visualization.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(name, output_attentions=True)
    return tok, mdl

model_name = "sshleifer/distilbart-cnn-12-6"
tokenizer, model = _load_model_and_tokenizer(model_name)

# --- User text input -----------------------------------------------------
# Free-form text the user wants summarized/visualized; empty until typed.
INPUT_LABEL = "๐ŸŒผ Enter text to summarize or visualize:"
INPUT_HEIGHT = 150
input_text = st.text_area(INPUT_LABEL, height=INPUT_HEIGHT)

# When user clicks the button
if st.button("โœจ Visualize Transformer Magic โœจ") and input_text:
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, output_attentions=True, return_dict_in_generate=True)
    
    decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    st.success("๐ŸŒธ Summary:")
    st.markdown(f"`{decoded}`")

    st.markdown("๐ŸŒท Attention Visualization (Encoder Self-Attention)")
    
    # Extract attentions
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
        attentions = outputs.encoder_attentions[0][0]

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.heatmap(attentions.mean(