Spaces:
Sleeping
Sleeping
File size: 5,898 Bytes
15d2ecf 6614d86 2cd4845 6614d86 15d2ecf 6614d86 8739835 6614d86 15d2ecf 6614d86 c37dcff f365c3f f77b776 f365c3f 47e052c f365c3f 47e052c f365c3f d899012 6614d86 1ae6640 0d7c17f 863acdd 5f6dfb8 bc22d0b 98066c4 7a90a04 6c8aab2 3e24d5f 0d7c17f 6614d86 02a3276 6614d86 15d2ecf 6614d86 15d2ecf 6614d86 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from io import BytesIO
import streamlit as st
import base64
from transformers import AutoModel, AutoTokenizer
from graphviz import Digraph
import json
def display_tree(output):
    """Render a dependency tree as an SVG and embed it in the Streamlit page.

    Parameters
    ----------
    output : list[dict]
        One dict per word with keys ``word`` (surface form),
        ``dep_head_idx`` (0-based index of the syntactic head, ``-1`` for
        the root) and ``dep_func`` (dependency label drawn on the edge).
    """
    # Graph width grows with the number of words; height is fixed at 5 inches.
    # NOTE: len(output) is already an int — no str(int(...)) round-trip needed.
    size = f'{len(output)},5'
    dpi = '300'
    fmt = 'svg'  # renamed from `format` to avoid shadowing the builtin

    dot = Digraph(engine='dot', format=fmt)
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi=dpi)

    # Add one node per word plus two kinds of edges.
    for i, word_info in enumerate(output):
        word = word_info['word']
        head_idx = word_info['dep_head_idx']
        dep_func = word_info['dep_func']
        dot.node(str(i), word)
        # Invisible edge to the previous word enforces linear word order
        # (reversed direction because the graph is laid out RTL-friendly LR).
        if i > 0:
            dot.edge(str(i), str(i - 1), style='invis')
        if head_idx != -1:
            # constraint='False' keeps dependency edges from affecting node
            # ranking, so the invisible chain alone controls word order.
            dot.edge(str(head_idx), str(i), label=dep_func, constraint='False')

    # Render once, in memory.  The old code additionally called
    # dot.render('syntax_tree', ...), writing syntax_tree.svg to disk for a
    # commented-out st.image call — that unused double render is dropped.
    svg_b64 = base64.b64encode(dot.pipe(format=fmt)).decode()

    # Display inside a scrollable container so long sentences stay usable.
    st.markdown(
        f"""
        <div style="height:250px; width:75vw; overflow:auto; border:1px solid #ccc; margin-left:-15vw">
            <img src="data:image/svg+xml;base64,{svg_b64}"
                 style="display: block; margin: auto; max-height: 240px;">
        </div>
        """, unsafe_allow_html=True)
def display_download(disp_string):
    """Offer *disp_string* for download as a plain-text file.

    The string is UTF-8 encoded and wrapped in an in-memory buffer for
    Streamlit's download widget.
    """
    buffer = BytesIO(disp_string.encode())
    st.download_button(
        label="⬇️ Download text file",
        data=buffer,
        file_name="parsed_output.txt",
        mime="text/plain",
    )
# --- Streamlit app setup: title, model loading, input widgets ---

# Streamlit app title
st.title('DictaBERT-Joint Visualizer')

# Load Hugging Face token from Streamlit secrets (required for the gated model).
hf_token = st.secrets["HF_TOKEN"] # Assuming you've set up the token in Streamlit secrets

# Authenticate and load model.
# NOTE(review): `use_auth_token` is deprecated in newer transformers releases
# in favor of `token` — confirm against the pinned transformers version.
# trust_remote_code=True is needed because dictabert-joint ships a custom
# model class (including the .predict() method used below).
tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=hf_token)
model = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=hf_token, trust_remote_code=True)
model.eval()  # inference only — disable dropout etc.

# Checkbox for the compute_mst parameter (maximum-spanning-tree decoding of syntax).
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)

# Output format selector; lowercased to match the model's expected
# output_style values ('json', 'ud', 'iahlt_ud').
output_style = st.selectbox(
    'Output Style: ',
    ('JSON', 'UD', 'IAHLT_UD'), index=1).lower()

# User input — the sentence to parse.
sentence = st.text_input('Enter a sentence to analyze:')
if sentence:
    # Echo the input sentence back to the user.
    st.text(sentence)

    # Model prediction — predict() takes a batch; we pass one sentence and
    # unwrap the single result.
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst, output_style=output_style)[0]

    if output_style == 'ud' or output_style == 'iahlt_ud':
        # UD / IAHLT-UD output: a list of strings — two header lines followed
        # by one tab-separated CoNLL-U row per token.
        ud_output = output

        # Convert to the tree format display_tree expects:
        # [dict(word, dep_head_idx, dep_func)]
        tree = []
        for l in ud_output[2:]:
            parts = l.split('\t')
            # Skip multi-word-token range rows (IDs like "1-2") — they carry
            # no dependency information of their own.
            if '-' in parts[0]: continue
            # CoNLL-U HEAD is 1-based with 0 = root; subtracting 1 maps the
            # root to -1, which display_tree treats as "no incoming edge".
            tree.append(dict(word=parts[1], dep_head_idx=int(parts[6]) - 1, dep_func=parts[7]))
        display_tree(tree)
        display_download('\n'.join(ud_output))

        # Construct the whole table as one Markdown/HTML string so it can be
        # rendered with a single st.markdown call.
        table_md = "<div dir='rtl' style='text-align: right;'>\n\n" # Start with RTL div

        # CSS for the cropped/scaled Google-Translate iframes embedded per word.
        st.markdown("""<style>
.google-translate-place {
width: 256px;
height: 128px;
}
.google-translate-crop {
width: 256px;
height: 128px;
overflow: scroll;
position: absolute;
}
.google-translate {
transform: scale(0.75);
transform-origin: 180px 200px;
position: relative;
left: -200px; top: -180px;
width: 2560px; height: 5120px;
position: absolute;
}
</style>""", unsafe_allow_html=True)

        # Add the two UD header lines (e.g. sent_id / text comments).
        table_md += "##" + ud_output[0] + "\n"
        table_md += "##" + ud_output[1] + "\n"
        # Table header — the 10 standard CoNLL-U columns.
        table_md += "| " + " | ".join(["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]) + " |\n"
        # Table alignment row.
        table_md += "| " + " | ".join(["---"]*10) + " |\n"

        for line in ud_output[2:]:
            # Each UD line becomes a table row.  Underscores are escaped so
            # Markdown doesn't treat them as emphasis.
            # NOTE(review): the '|' and ':' replace() calls replace a character
            # with itself — they are no-ops; presumably escapes were intended.
            cells = line.replace('_', '\\_').replace('|', '|').replace(':', ':').split('\t')
            wrd = cells[2]
            # "\_" here is the escaped-underscore placeholder produced above
            # (LEMMA column value "_"); skip the iframe widgets for it.
            if wrd != "\_":
                # Replace the LEMMA cell with a stack of lookup widgets for the
                # word: Google Translate, Google Ngrams, two freeali.se helper
                # frames, and a dict.com link.
                cells[2] = "<div class='google-translate-place'><div class='google-translate-crop'><iframe class='google-translate' src='https://www.google.com/search?igu=1&q=" + wrd + "+in+English+google+translate&authuser=0&hl=en-US' width='256' height='128'></iframe></div></div><br/>"
                cells[2] += "<iframe src='https://books.google.com/ngrams/interactive_chart?content=" + wrd + "_*&year_start=1800&year_end=2022&corpus=iw&smoothing=50' width='256' height='128'></iframe><br/>"
                cells[2] += "<iframe src='https://freeali.se/freealise/translate/loader.htm?q=" + wrd + "&a=conj' width='256' height='128'></iframe><br/>"
                cells[2] += "<iframe src='https://freeali.se/freealise/translate/loader.htm?q=" + wrd + "&a=def' width='256' height='128'></iframe><br/>"
                cells[2] += "<a href='https://dict.com/hebrew-english/" + wrd + "' target='_blank'>" + wrd + "</a>"
            table_md += "| " + " | ".join(cells) + " |\n"

        table_md += "</div>" # Close the RTL div
        # NOTE(review): leftover debug output — consider removing.
        print(table_md)
        # Display the table using a single markdown call.
        st.markdown(table_md, unsafe_allow_html=True)
    else:
        # JSON output style: output is a dict whose 'tokens' entries each
        # carry a 'syntax' dict in the shape display_tree expects.
        tree = [w['syntax'] for w in output['tokens']]
        display_tree(tree)
        # ensure_ascii=False keeps Hebrew text readable in the download/JSON view.
        json_output = json.dumps(output, ensure_ascii=False, indent=2)
        display_download(json_output)
        # And show the full JSON inline as a fenced code block.
        st.markdown("```json\n" + json_output + "\n```")
|