# Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import spacy | |
import spacy_transformers | |
from huggingface_hub import snapshot_download | |
import os | |
from collections import Counter | |
# Download the model from Hugging Face Hub and load it.
model_repo = "hksung/ASC_tagger_v2"
model_path = snapshot_download(model_repo)  # the model; public
nlp = spacy.load(os.path.join(model_path, "model-best"))

# The rest of this file iterates doc.sents, so some component has to set
# sentence boundaries. Add a rule-based sentencizer only when the pipeline has
# none of the boundary-setting components. "sentencizer" itself is part of the
# check so that a pipeline which already ships one is not asked to register a
# second component under the same name.
_SENTENCE_BOUNDARY_PIPES = ("parser", "senter", "sentencizer")
if not any(name in nlp.pipe_names for name in _SENTENCE_BOUNDARY_PIPES):
    nlp.add_pipe("sentencizer")
# CSS block prepended to every highlighted result. It styles the container,
# the per-entity highlight colours, the label shown under each entity (via
# ::after and the data-entity attribute), and a dark-mode variant.
_HIGHLIGHT_STYLE = """
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f4f4f4;
color: #000;
}
.container {
max-width: 1000px;
margin: 0 auto;
padding: 20px;
background-color: white;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
line-height: 2.2em; /* Increased line spacing */
}
.entity {
display: inline-block;
border: none;
border-radius: 2px;
padding: 2px 5px;
margin: 0 4px;
position: relative;
white-space: nowrap;
line-height: 1.2;
font-size: 0.8em;
/* Ensure default text color */
color: inherit;
}
/* Highlight background colors for entity types */
.entity[data-entity="ATTR"] { background-color: #dbb6ab; }
.entity[data-entity="INTRAN_S"] { background-color: #e7957f; }
.entity[data-entity="INTRAN_MOT"] { background-color: #ebab22; }
.entity[data-entity="INTRAN_RES"] { background-color: #f095cc; }
.entity[data-entity="CAUS_MOT"] { background-color: #85a831; }
.entity[data-entity="TRAN_S"] { background-color: #a0d4f7; }
.entity[data-entity="TRAN_RES"] { background-color: #c7aefa; }
.entity[data-entity="DITRAN"] { background-color: #b3f0f7; }
.entity[data-entity="PASSIVE"] { background-color: #c3c0c0; }
.entity[data-entity=""] { background-color: #cccccc; }
/* Darker background colors for the entity label tooltips */
.entity[data-entity="ATTR"]::after { background-color: #d29997; }
.entity[data-entity="INTRAN_S"]::after { background-color: #ec6161; }
.entity[data-entity="INTRAN_MOT"]::after { background-color: #eb9422; }
.entity[data-entity="INTRAN_RES"]::after { background-color: #be5791; }
.entity[data-entity="CAUS_MOT"]::after { background-color: #007030; }
.entity[data-entity="TRAN_S"]::after { background-color: #3085ce; }
.entity[data-entity="TRAN_RES"]::after { background-color: #8268cf; }
.entity[data-entity="DITRAN"]::after { background-color: #449cbb; }
.entity[data-entity="PASSIVE"]::after { background-color: #6b6b6b; }
.entity[data-entity=""]::after { background-color: #888888; }
/* Styling for the entity label tooltip */
.entity::after {
content: attr(data-entity);
position: absolute;
bottom: -2em;
left: 0;
right: 0;
color: #fff;
font-size: 0.6em;
padding: 2px 4px;
border-radius: 2px;
text-align: center;
min-width: 60px;
white-space: nowrap;
}
/* Dark mode adjustments */
@media (prefers-color-scheme: dark) {
body {
background-color: #181818;
color: #e0e0e0;
}
.container {
background-color: #282828;
box-shadow: 0 0 10px rgba(255, 255, 255, 0.1);
}
/* Ensure text inside entities is visible in dark mode */
.entity {
color: #e0e0e0;
}
.entity::after {
color: #fff;
}
}
</style>
"""


def get_highlighted_text(doc):
    """
    Render a processed document as HTML with every detected ASC highlighted.

    Each sentence of ``doc`` is emitted on its own, separated by
    ``<br><br>``. Every entity that lies completely within a sentence is
    wrapped in ``<span class="entity" data-entity="LABEL">``, where LABEL is
    the entity's ``label_``. Entities that cross a sentence boundary are
    skipped.

    All text taken from the document (sentence text and entity labels) is
    HTML-escaped before it is placed in the markup, so characters such as
    ``<``, ``>`` and ``&`` in the user's input are displayed literally instead
    of being interpreted as HTML.

    Args:
        doc: An object exposing ``sents`` and ``ents``. Sentences need
            ``text``, ``start_char`` and ``end_char``; entities need
            ``start_char``, ``end_char`` and ``label_``.

    Returns:
        The CSS block followed by ``<div class='container'>…</div>`` holding
        the highlighted sentences.
    """
    import html  # function-scope import keeps this block self-contained

    highlighted_sentences = []
    for sent in doc.sents:
        sent_start = sent.start_char
        sent_end = sent.end_char
        text = sent.text

        # Entities completely inside this sentence, in reading order.
        ents_in_sent = sorted(
            (
                ent
                for ent in doc.ents
                if ent.start_char >= sent_start and ent.end_char <= sent_end
            ),
            key=lambda ent: ent.start_char,
        )

        # Build the sentence left to right from escaped pieces. Escaping has
        # to happen per piece: escaping the finished string would also escape
        # the <span> tags, and escaping the raw sentence first would shift the
        # character offsets the entities refer to.
        pieces = []
        cursor = 0
        for ent in ents_in_sent:
            ent_start = ent.start_char - sent_start
            ent_end = ent.end_char - sent_start
            pieces.append(html.escape(text[cursor:ent_start]))
            pieces.append(
                f'<span class="entity" data-entity="{html.escape(ent.label_)}">'
            )
            pieces.append(html.escape(text[ent_start:ent_end]))
            pieces.append("</span>")
            cursor = ent_end
        pieces.append(html.escape(text[cursor:]))
        highlighted_sentences.append("".join(pieces))

    result = "<br><br>".join(highlighted_sentences)
    return _HIGHLIGHT_STYLE + "<div class='container'>" + result + "</div>"
def process_text(input_text):
    """
    Process input text to detect and tag ASCs.

    Args:
        input_text: The raw text to analyse. ``None`` and whitespace-only
            strings are treated as "no input".

    Returns:
        A tuple ``(html_or_message, total_count, counts_by_tag)``:
        the highlighted HTML (or a plain explanatory message when there is
        nothing to show), the total number of detected ASCs, and a dict
        mapping each tag label to how often it occurred.
    """
    # Guard against None as well as blank input; calling .strip() on None
    # would raise AttributeError before any message could be returned.
    if input_text is None or not input_text.strip():
        return "No text provided. Please enter some text.", 0, {}

    doc = nlp(input_text)
    if not list(doc.sents):
        return "Please enter at least one sentence.", 0, {}
    if not doc.ents:
        return "No ASCs were detected.", 0, {}

    highlighted_html = get_highlighted_text(doc)
    # Count each tag type using collections.Counter
    tag_counter = Counter(ent.label_ for ent in doc.ents)
    total_count = sum(tag_counter.values())
    return highlighted_html, total_count, dict(tag_counter)
# Demo passage offered through the "Sample text" button. It is one long
# sentence, written here as adjacent string literals that Python joins into a
# single string.
sample_text = (
    "When, while the lovely valley teems with vapor around me, "
    "and the meridian sun strikes the upper surface of the impenetrable foliage of my trees, "
    "and but a few stray gleams steal into the inner sanctuary, "
    "I throw myself down among the tall grass by the trickling stream; "
    "and, as I lie close to the earth, a thousand unknown plants are noticed by me: "
    "when I hear the buzz of the little world among the stalks, "
    "and grow familiar with the countless indescribable forms of the insects and flies, "
    "then I feel the presence of the Almighty, who formed us in his own image, "
    "and the breath of that celestial force fills my lungs with an ineffable wonder, "
    "drawing my soul into a silent communion with the eternal rhythms of the earth."
)


def fill_sample_text():
    """Return the built-in sample passage."""
    return sample_text
# Build the Gradio interface. Components are created inside the Blocks
# context in the order written here: heading, intro text, input box, the three
# outputs, then the two buttons.
with gr.Blocks() as demo:
    gr.Markdown("# ASC tagger demo")
    gr.Markdown(
        "Enter some text to have ASCs tagged. Use the button below to fill in sample text. "
        "Learn more about the related works [here](https://github.com/LCR-ADS-Lab/ASC-Treebank)."
    )
    input_textbox = gr.Textbox(lines=5, label="Input text", placeholder="Enter text here...")
    # The three outputs line up with the three values returned by
    # process_text: HTML (or message), total count, counts per tag.
    output_html = gr.HTML(label="Tagged text")
    output_count = gr.Number(label="Number of ASCs detected", precision=0)
    output_counts_by_type = gr.JSON(label="ASC counts by type")
    tag_btn = gr.Button("Tag ASCs")
    fill_btn = gr.Button("Sample text")
    # "Sample text" takes no inputs and writes the sample passage into the
    # input box.
    fill_btn.click(fn=fill_sample_text, inputs=[], outputs=input_textbox)
    # "Tag ASCs" runs the tagger on the input box and fills all three outputs.
    tag_btn.click(
        fn=process_text,
        inputs=input_textbox,
        outputs=[output_html, output_count, output_counts_by_type]
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()