import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import matplotlib.pyplot as plt
import seaborn as sns
# ---- Page configuration & header -------------------------------------------
st.set_page_config(page_title="Transflower ๐ธ", layout="centered")

_TITLE_HTML = "<h1 style='text-align: center; color: #D16BA5;'>Transflower ๐ธ</h1>"
_TAGLINE_HTML = "<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.markdown(_TAGLINE_HTML, unsafe_allow_html=True)

# ---- Summarization model ----------------------------------------------------
# DistilBART checkpoint fine-tuned for CNN/DailyMail summarization; attentions
# are requested at load time so they can be visualized later.
model_name = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)

# ---- User input --------------------------------------------------------------
input_text = st.text_area("๐ผ Enter text to summarize or visualize:", height=150)
# When the user clicks the button (and has entered text), summarize it and
# prepare the encoder self-attention visualization.
if st.button("โจ Visualize Transformer Magic โจ") and input_text:
    # Tokenize the raw text into model-ready PyTorch tensors.
    inputs = tokenizer(input_text, return_tensors="pt")

    # Generate the summary. With return_dict_in_generate=True and
    # output_attentions=True, the result also carries the encoder
    # self-attentions, so a second full forward pass is unnecessary.
    with torch.no_grad():
        output = model.generate(
            **inputs, output_attentions=True, return_dict_in_generate=True
        )
    decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    st.success("๐ธ Summary:")
    st.markdown(f"`{decoded}`")

    st.markdown("๐ท Attention Visualization (Encoder Self-Attention)")
    # encoder_attentions is a tuple with one tensor per encoder layer, each
    # shaped (batch, heads, seq, seq). Take layer 0, batch item 0 ->
    # (heads, seq, seq); downstream code averages over the heads axis.
    attentions = output.encoder_attentions[0][0]
    fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(attentions.mean(