|
import streamlit as st |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
|
|
# Page chrome: tab title, layout, and the styled header/subtitle banner.
# NOTE(review): the emoji in these strings look mojibake-encoded ("πΈ", likely
# a UTF-8 🌸 read back under the wrong codec) — confirm intended glyphs and
# re-save the file as UTF-8. Strings are reproduced here exactly as found.
st.set_page_config(page_title="Transflower πΈ", layout="centered")

_HEADER_HTML = "<h1 style='text-align: center; color: #D16BA5;'>Transflower πΈ</h1>"
_SUBTITLE_HTML = "<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>"

st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.markdown(_SUBTITLE_HTML, unsafe_allow_html=True)
|
|
|
|
|
model_name = "sshleifer/distilbart-cnn-12-6"


@st.cache_resource
def load_model(name: str):
    """Load the tokenizer and seq2seq model for *name*, cached per process.

    Streamlit re-executes the entire script on every widget interaction;
    without caching, the original top-level calls re-instantiated (and
    potentially re-downloaded) the model on every rerun. ``st.cache_resource``
    makes loading a one-time cost shared across sessions.

    Returns:
        (tokenizer, model) pair; the model is configured to emit attention
        weights (``output_attentions=True``) for the visualization below.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(name, output_attentions=True)
    return tok, mdl


# Keep the original module-level names so the rest of the script is unchanged.
tokenizer, model = load_model(model_name)
|
|
|
|
|
input_text = st.text_area("πΌ Enter text to summarize or visualize:", height=150) |
|
|
|
|
|
if st.button("β¨ Visualize Transformer Magic β¨") and input_text:
    inputs = tokenizer(input_text, return_tensors="pt")

    # Single pass: with output_attentions=True and return_dict_in_generate=True,
    # generate() already returns the encoder self-attentions, so the extra
    # model(**inputs, output_attentions=True) forward pass the original ran
    # just to fetch them was redundant work and has been removed.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            output_attentions=True,
            return_dict_in_generate=True,
        )

    # Decode the generated summary (first — and only — sequence in the batch).
    decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    st.success("πΈ Summary:")
    st.markdown(f"`{decoded}`")

    st.markdown("π· Attention Visualization (Encoder Self-Attention)")

    # encoder_attentions is a tuple with one tensor per encoder layer, each
    # shaped (batch, heads, seq, seq). Take layer 0, batch item 0 — giving
    # (heads, seq, seq), the same value the original second forward pass
    # produced — for the heatmap code that follows.
    attentions = output.encoder_attentions[0][0]
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
sns.heatmap(attentions.mean( |
|
|