rahideer committed on
Commit
44cb668
·
verified ·
1 Parent(s): d820d3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -10,14 +10,15 @@ st.set_page_config(page_title="Transflower 🌸", page_icon="🌼", layout="cent
10
 
11
  st.markdown(
12
  "<h1 style='text-align: center; color: pink;'>🌸 Transflower 🌸</h1>"
13
- "<p style='text-align: center; color: gray;'>A girly and cute app to visualize Transformer magic</p>",
14
  unsafe_allow_html=True,
15
  )
16
 
17
  # Load model and tokenizer
18
  model_name = "t5-small"
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
 
21
 
22
  # Input area
23
  user_input = st.text_area("🌼 Enter text to summarize or visualize:", height=200)
@@ -26,24 +27,25 @@ if st.button("✨ Visualize Transformer Magic ✨"):
26
  if not user_input.strip():
27
  st.warning("Please enter some text to visualize.")
28
  else:
29
- # Prepare input
30
- input_ids = tokenizer.encode("summarize: " + user_input, return_tensors="pt", max_length=512, truncation=True)
31
-
32
- # Forward pass with attentions
33
  with torch.no_grad():
34
- outputs = model.generate(input_ids, output_attentions=True, return_dict_in_generate=True, output_scores=True)
35
- decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
36
 
37
- st.subheader("🌸 Summary:")
38
- st.success(decoded)
 
 
39
 
40
- # Visualization
41
- st.subheader("πŸ’– Attention Heatmap:")
42
- fig, ax = plt.subplots(figsize=(10, 5))
43
 
44
- # Get decoder self-attention from the last layer
45
- attention_data = outputs.attentions[-1] # List of attention tensors from each layer
46
- avg_attention = attention_data[0].mean(dim=0).squeeze().detach().numpy() # mean over heads
47
 
48
- sns.heatmap(avg_attention, cmap="coolwarm", ax=ax)
 
 
49
  st.pyplot(fig)
 
10
 
11
  st.markdown(
12
  "<h1 style='text-align: center; color: pink;'>🌸 Transflower 🌸</h1>"
13
+
14
  unsafe_allow_html=True,
15
  )
16
 
17
  # Load model and tokenizer
18
  model_name = "t5-small"
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
21
+ model.eval()
22
 
23
  # Input area
24
  user_input = st.text_area("🌼 Enter text to summarize or visualize:", height=200)
 
27
  if not user_input.strip():
28
  st.warning("Please enter some text to visualize.")
29
  else:
30
+ # Encode input
31
+ inputs = tokenizer("summarize: " + user_input, return_tensors="pt", truncation=True)
32
+
33
+ # Forward pass manually to get attention
34
  with torch.no_grad():
35
+ encoder_outputs = model.encoder(**inputs, output_attentions=True, return_dict=True)
36
+ attention = encoder_outputs.attentions[-1][0].mean(dim=0).detach().numpy()
37
 
38
+ # Generate summary
39
+ with torch.no_grad():
40
+ summary_ids = model.generate(inputs["input_ids"], max_length=50)
41
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
42
 
43
+ st.subheader("🌸 Summary:")
44
+ st.success(summary)
 
45
 
46
+ st.subheader("πŸ’– Encoder Attention Heatmap:")
 
 
47
 
48
+ fig, ax = plt.subplots(figsize=(10, 6))
49
+ sns.heatmap(attention, cmap="YlGnBu", ax=ax)
50
+ ax.set_title("Encoder Self-Attention Heatmap πŸ’«")
51
  st.pyplot(fig)