Spaces:
Running
Running
Hakyung Sung
commited on
Commit
·
0bd74cf
1
Parent(s):
f54d77f
Correct app.py
Browse files
app.py
CHANGED
@@ -18,143 +18,73 @@ model_path = snapshot_download(model_repo) # Assumes the repo is public; add to
|
|
18 |
nlp = spacy.load(os.path.join(model_path, 'model-best'))
|
19 |
|
20 |
# Make sure the pipeline can split into sentences
|
|
|
|
|
|
|
21 |
if 'parser' not in nlp.pipe_names and 'senter' not in nlp.pipe_names:
|
22 |
nlp.add_pipe('sentencizer')
|
23 |
|
24 |
def get_highlighted_text(doc):
|
25 |
"""
|
26 |
-
|
27 |
-
|
28 |
-
Returns a list of HTML strings (one per sentence).
|
29 |
"""
|
30 |
highlighted_sentences = []
|
31 |
for sent in doc.sents:
|
32 |
-
sent_start = sent.start_char
|
33 |
-
sent_end = sent.end_char
|
34 |
sent_text = sent.text
|
35 |
-
#
|
36 |
-
ents_in_sent = [ent for ent in doc.ents if ent.start_char >=
|
37 |
-
|
38 |
if ents_in_sent:
|
39 |
-
# Process entities
|
40 |
-
# replacing text doesn’t mess up subsequent character indices.
|
41 |
ents_in_sent = sorted(ents_in_sent, key=lambda x: x.start_char, reverse=True)
|
42 |
s = sent_text
|
43 |
for ent in ents_in_sent:
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
# Wrap the entity
|
48 |
s = (
|
49 |
s[:ent_start]
|
50 |
-
+ f'<span
|
|
|
|
|
51 |
+ s[ent_end:]
|
52 |
)
|
53 |
highlighted_sentences.append(s)
|
54 |
else:
|
55 |
highlighted_sentences.append(sent_text)
|
56 |
-
|
57 |
-
|
58 |
-
def create_tag_count_plot(tag_counts):
|
59 |
-
"""
|
60 |
-
Create a Plotly bar chart that shows the count for each entity tag.
|
61 |
-
Returns the bar chart as a base64 encoded PNG image.
|
62 |
-
"""
|
63 |
-
sorted_tags = sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)
|
64 |
-
tags, counts = zip(*sorted_tags)
|
65 |
-
fig = go.Figure(data=[
|
66 |
-
go.Bar(
|
67 |
-
x=tags,
|
68 |
-
y=counts,
|
69 |
-
text=counts,
|
70 |
-
textposition='auto',
|
71 |
-
marker=dict(color='#8ABB40'),
|
72 |
-
hoverinfo='x+y'
|
73 |
-
)
|
74 |
-
])
|
75 |
-
fig.update_layout(
|
76 |
-
xaxis_title='Tag',
|
77 |
-
yaxis_title='Count',
|
78 |
-
template='none',
|
79 |
-
font=dict(size=10, family="Arial, sans-serif"),
|
80 |
-
xaxis_tickangle=-45,
|
81 |
-
margin=dict(l=50, r=50, t=50, b=80),
|
82 |
-
plot_bgcolor='white',
|
83 |
-
paper_bgcolor='white',
|
84 |
-
width=400,
|
85 |
-
height=550
|
86 |
-
)
|
87 |
-
img_buffer = io.BytesIO()
|
88 |
-
pio.write_image(fig, img_buffer, format='png')
|
89 |
-
img_buffer.seek(0)
|
90 |
-
plot_b64 = base64.b64encode(img_buffer.getvalue()).decode('utf8')
|
91 |
-
return plot_b64
|
92 |
-
|
93 |
-
def base64_to_pil(b64_str):
|
94 |
-
"""
|
95 |
-
Convert a base64 encoded string to a PIL Image.
|
96 |
-
"""
|
97 |
-
img_data = base64.b64decode(b64_str)
|
98 |
-
return Image.open(io.BytesIO(img_data))
|
99 |
|
100 |
def process_text(input_text):
|
101 |
"""
|
102 |
-
Process the user
|
103 |
-
|
104 |
-
- Otherwise, produce HTML with highlighted entities and a bar chart image of tag counts.
|
105 |
"""
|
106 |
-
error_message = ""
|
107 |
-
html_output = ""
|
108 |
-
plot_image = None
|
109 |
-
|
110 |
if not input_text.strip():
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
doc = nlp(input_text)
|
115 |
-
sentences
|
116 |
-
if len(
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
# Generate highlighted sentences and join them with HTML breaks.
|
125 |
-
highlighted = get_highlighted_text(doc)
|
126 |
-
html_output = "<br><br>".join(highlighted)
|
127 |
-
|
128 |
-
# Get a counter for entity tags and create a bar chart.
|
129 |
-
tag_counts = Counter([ent.label_ for ent in doc.ents])
|
130 |
-
plot_b64 = create_tag_count_plot(tag_counts)
|
131 |
-
plot_image = base64_to_pil(plot_b64)
|
132 |
-
|
133 |
-
return html_output, plot_image, error_message
|
134 |
|
135 |
# Build the Gradio interface.
|
136 |
with gr.Blocks() as demo:
|
137 |
-
gr.Markdown("#
|
138 |
-
gr.Markdown(
|
139 |
-
|
140 |
-
)
|
141 |
-
|
142 |
-
# Input textbox for user text.
|
143 |
input_textbox = gr.Textbox(lines=10, label="Input Text", placeholder="Enter text here...")
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
# Three outputs: highlighted text (as HTML), tag count plot image, and error message display.
|
149 |
-
highlighted_output = gr.HTML(label="Highlighted Sentences")
|
150 |
-
tag_plot_output = gr.Image(label="Tag Count Plot")
|
151 |
-
error_output = gr.Textbox(label="Error Message", interactive=False)
|
152 |
-
|
153 |
-
analyze_btn.click(
|
154 |
-
fn=process_text,
|
155 |
-
inputs=input_textbox,
|
156 |
-
outputs=[highlighted_output, tag_plot_output, error_output]
|
157 |
-
)
|
158 |
|
159 |
if __name__ == "__main__":
|
160 |
-
demo.launch()
|
|
|
18 |
nlp = spacy.load(os.path.join(model_path, 'model-best'))
|
19 |
|
20 |
# Make sure the pipeline can split into sentences
|
21 |
+
if 'parser' not in nlp.pipe_names and 'senter' not in nlp.pipe_names:
|
22 |
+
nlp.add_pipe('sentencizer')
|
23 |
+
# If the pipeline is missing a sentence splitter, add one
|
24 |
if 'parser' not in nlp.pipe_names and 'senter' not in nlp.pipe_names:
|
25 |
nlp.add_pipe('sentencizer')
|
26 |
|
27 |
def get_highlighted_text(doc):
|
28 |
"""
|
29 |
+
Wrap detected ASCs (entities) in each sentence with a span tag that has a custom inline style.
|
30 |
+
Here, we assume all entities from the model correspond to ASCs.
|
|
|
31 |
"""
|
32 |
highlighted_sentences = []
|
33 |
for sent in doc.sents:
|
|
|
|
|
34 |
sent_text = sent.text
|
35 |
+
# Get entities that are fully contained within the sentence.
|
36 |
+
ents_in_sent = [ent for ent in doc.ents if ent.start_char >= sent.start_char and ent.end_char <= sent.end_char]
|
|
|
37 |
if ents_in_sent:
|
38 |
+
# Process entities in reverse order to avoid messing up character indices.
|
|
|
39 |
ents_in_sent = sorted(ents_in_sent, key=lambda x: x.start_char, reverse=True)
|
40 |
s = sent_text
|
41 |
for ent in ents_in_sent:
|
42 |
+
# Compute positions relative to the sentence start
|
43 |
+
ent_start = ent.start_char - sent.start_char
|
44 |
+
ent_end = ent.end_char - sent.start_char
|
45 |
+
# Wrap the entity in a span with a custom style. Adjust color & style as needed.
|
46 |
s = (
|
47 |
s[:ent_start]
|
48 |
+
+ f'<span style="background-color: #add8e6; font-weight: bold;" title="{ent.label_}">'
|
49 |
+
+ s[ent_start:ent_end]
|
50 |
+
+ '</span>'
|
51 |
+ s[ent_end:]
|
52 |
)
|
53 |
highlighted_sentences.append(s)
|
54 |
else:
|
55 |
highlighted_sentences.append(sent_text)
|
56 |
+
# Join sentences with HTML breaks so the output preserves sentence separations.
|
57 |
+
return "<br><br>".join(highlighted_sentences)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def process_text(input_text):
|
60 |
"""
|
61 |
+
Process the user input text to detect and tag ASCs.
|
62 |
+
Returns an HTML string with tagged entities.
|
|
|
63 |
"""
|
|
|
|
|
|
|
|
|
64 |
if not input_text.strip():
|
65 |
+
return "No text provided. Please enter some text."
|
66 |
+
|
|
|
67 |
doc = nlp(input_text)
|
68 |
+
# Check if there are sentences; if not, return a message.
|
69 |
+
if len(list(doc.sents)) == 0:
|
70 |
+
return "Please enter at least one sentence."
|
71 |
+
# If no entities (ASCs) are found, let the user know.
|
72 |
+
if not doc.ents:
|
73 |
+
return "No ASCs were detected."
|
74 |
+
|
75 |
+
# Get the HTML with highlighted ASCs.
|
76 |
+
return get_highlighted_text(doc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
# Build the Gradio interface.
|
79 |
with gr.Blocks() as demo:
|
80 |
+
gr.Markdown("# ASC Tagger")
|
81 |
+
gr.Markdown("Enter some text to have ASCs tagged (highlighted with a custom color scheme).")
|
82 |
+
|
|
|
|
|
|
|
83 |
input_textbox = gr.Textbox(lines=10, label="Input Text", placeholder="Enter text here...")
|
84 |
+
output_html = gr.HTML(label="Tagged Text")
|
85 |
+
tag_btn = gr.Button("Tag ASCs")
|
86 |
+
|
87 |
+
tag_btn.click(fn=process_text, inputs=input_textbox, outputs=output_html)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
if __name__ == "__main__":
|
90 |
+
demo.launch()
|