# LLM / app.py — Hugging Face Space file (author: rahideer)
# Last change: "Update app.py", commit fa35904 (verified), 1.56 kB.
# NOTE(review): these lines were scraped from the HF "raw / history / blame"
# page chrome, not part of the program; preserved here as comments so the
# file stays valid Python.
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import matplotlib.pyplot as plt
import seaborn as sns
# ---------------------------------------------------------------------------
# Transflower: a Streamlit app that summarizes text with DistilBART and
# visualizes the encoder self-attention of the input.
# NOTE(review): the original paste had its indentation stripped (invalid
# Python); the block structure below is the restored control flow.
# ---------------------------------------------------------------------------

# App title and styling
st.set_page_config(page_title="Transflower 🌸", layout="centered")
st.markdown("<h1 style='text-align: center; color: #D16BA5;'>Transflower 🌸</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>", unsafe_allow_html=True)

model_name = "sshleifer/distilbart-cnn-12-6"


@st.cache_resource
def _load_model(name: str):
    """Load (tokenizer, model) for *name* once per process.

    Streamlit re-executes this script on every interaction; without
    st.cache_resource the multi-hundred-MB model would be re-downloaded /
    re-instantiated on each rerun.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(name, output_attentions=True)
    return tok, mdl


tokenizer, model = _load_model(model_name)

# Text input
input_text = st.text_area("🌼 Enter text to summarize or visualize:", height=150)

# When the user clicks the button (and actually typed something)
if st.button("✨ Visualize Transformer Magic ✨") and input_text:
    # truncation=True keeps inputs within the model's max positions instead
    # of crashing on long pastes.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model.generate(**inputs, output_attentions=True, return_dict_in_generate=True)
    decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    st.success("🌸 Summary:")
    st.markdown(f"`{decoded}`")
    st.markdown("🌷 Attention Visualization (Encoder Self-Attention)")

    # generate(..., output_attentions=True, return_dict_in_generate=True)
    # already returns the encoder self-attentions, so the original's second
    # full forward pass is unnecessary.  Index [0] = first layer,
    # [0] again = first (only) batch element -> (num_heads, seq, seq).
    attentions = output.encoder_attentions[0][0]
    fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(attentions.mean(