rahideer commited on
Commit
d820d3f
ยท
verified ยท
1 Parent(s): 1ac1b38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -28
app.py CHANGED
@@ -1,38 +1,49 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
 
6
 
7
- # App Title and Styling
8
- st.set_page_config(page_title="Transflower ๐ŸŒธ", layout="centered")
9
- st.markdown("<h1 style='text-align: center; color: #D16BA5;'>Transflower ๐ŸŒธ</h1>", unsafe_allow_html=True)
10
- st.markdown("<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>", unsafe_allow_html=True)
 
 
 
 
11
 
12
  # Load model and tokenizer
13
- model_name = "sshleifer/distilbart-cnn-12-6"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
16
 
17
- # Text Input
18
- input_text = st.text_area("๐ŸŒผ Enter text to summarize or visualize:", height=150)
19
-
20
- # When user clicks the button
21
- if st.button("โœจ Visualize Transformer Magic โœจ") and input_text:
22
- inputs = tokenizer(input_text, return_tensors="pt")
23
- with torch.no_grad():
24
- output = model.generate(**inputs, output_attentions=True, return_dict_in_generate=True)
25
-
26
- decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
27
- st.success("๐ŸŒธ Summary:")
28
- st.markdown(f"`{decoded}`")
29
-
30
- st.markdown("๐ŸŒท Attention Visualization (Encoder Self-Attention)")
31
-
32
- # Extract attentions
33
- with torch.no_grad():
34
- outputs = model(**inputs, output_attentions=True)
35
- attentions = outputs.encoder_attentions[0][0]
36
-
37
- fig, ax = plt.subplots(figsize=(10, 6))
38
- sns.heatmap(attentions.mean(
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
+ import numpy as np
7
 
8
+ # Page setup
9
+ st.set_page_config(page_title="Transflower ๐ŸŒธ", page_icon="๐ŸŒผ", layout="centered")
10
+
11
+ st.markdown(
12
+ "<h1 style='text-align: center; color: pink;'>๐ŸŒธ Transflower ๐ŸŒธ</h1>"
13
+ "<p style='text-align: center; color: gray;'>A girly and cute app to visualize Transformer magic</p>",
14
+ unsafe_allow_html=True,
15
+ )
16
 
17
  # Load model and tokenizer
18
+ model_name = "t5-small"
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
21
 
22
+ # Input area
23
+ user_input = st.text_area("๐ŸŒผ Enter text to summarize or visualize:", height=200)
24
+
25
+ if st.button("โœจ Visualize Transformer Magic โœจ"):
26
+ if not user_input.strip():
27
+ st.warning("Please enter some text to visualize.")
28
+ else:
29
+ # Prepare input
30
+ input_ids = tokenizer.encode("summarize: " + user_input, return_tensors="pt", max_length=512, truncation=True)
31
+
32
+ # Forward pass with attentions
33
+ with torch.no_grad():
34
+ outputs = model.generate(input_ids, output_attentions=True, return_dict_in_generate=True, output_scores=True)
35
+ decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
36
+
37
+ st.subheader("๐ŸŒธ Summary:")
38
+ st.success(decoded)
39
+
40
+ # Visualization
41
+ st.subheader("๐Ÿ’– Attention Heatmap:")
42
+ fig, ax = plt.subplots(figsize=(10, 5))
43
+
44
+ # Get decoder self-attention from the last layer
45
+ attention_data = outputs.attentions[-1] # List of attention tensors from each layer
46
+ avg_attention = attention_data[0].mean(dim=0).squeeze().detach().numpy() # mean over heads
47
+
48
+ sns.heatmap(avg_attention, cmap="coolwarm", ax=ax)
49
+ st.pyplot(fig)