Hakyung Sung committed on
Commit
0bd74cf
·
1 Parent(s): f54d77f

Correct app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -108
app.py CHANGED
@@ -18,143 +18,73 @@ model_path = snapshot_download(model_repo) # Assumes the repo is public; add to
18
  nlp = spacy.load(os.path.join(model_path, 'model-best'))
19
 
20
  # Make sure the pipeline can split into sentences
 
 
 
21
  if 'parser' not in nlp.pipe_names and 'senter' not in nlp.pipe_names:
22
  nlp.add_pipe('sentencizer')
23
 
24
  def get_highlighted_text(doc):
25
  """
26
- For each sentence in the document, check if there are any entities.
27
- If so, insert HTML spans with a class and data attribute to highlight them.
28
- Returns a list of HTML strings (one per sentence).
29
  """
30
  highlighted_sentences = []
31
  for sent in doc.sents:
32
- sent_start = sent.start_char
33
- sent_end = sent.end_char
34
  sent_text = sent.text
35
- # Find entities that fall completely within this sentence.
36
- ents_in_sent = [ent for ent in doc.ents if ent.start_char >= sent_start and ent.end_char <= sent_end]
37
-
38
  if ents_in_sent:
39
- # Process entities from the end to the start so that
40
- # replacing text doesn’t mess up subsequent character indices.
41
  ents_in_sent = sorted(ents_in_sent, key=lambda x: x.start_char, reverse=True)
42
  s = sent_text
43
  for ent in ents_in_sent:
44
- ent_start = ent.start_char - sent_start
45
- ent_end = ent.end_char - sent_start
46
- ent_label = ent.label_
47
- # Wrap the entity text in a span tag with a CSS class and data attribute
48
  s = (
49
  s[:ent_start]
50
- + f'<span class="entity" data-entity="{ent_label}">{s[ent_start:ent_end]}</span>'
 
 
51
  + s[ent_end:]
52
  )
53
  highlighted_sentences.append(s)
54
  else:
55
  highlighted_sentences.append(sent_text)
56
- return highlighted_sentences
57
-
58
- def create_tag_count_plot(tag_counts):
59
- """
60
- Create a Plotly bar chart that shows the count for each entity tag.
61
- Returns the bar chart as a base64 encoded PNG image.
62
- """
63
- sorted_tags = sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)
64
- tags, counts = zip(*sorted_tags)
65
- fig = go.Figure(data=[
66
- go.Bar(
67
- x=tags,
68
- y=counts,
69
- text=counts,
70
- textposition='auto',
71
- marker=dict(color='#8ABB40'),
72
- hoverinfo='x+y'
73
- )
74
- ])
75
- fig.update_layout(
76
- xaxis_title='Tag',
77
- yaxis_title='Count',
78
- template='none',
79
- font=dict(size=10, family="Arial, sans-serif"),
80
- xaxis_tickangle=-45,
81
- margin=dict(l=50, r=50, t=50, b=80),
82
- plot_bgcolor='white',
83
- paper_bgcolor='white',
84
- width=400,
85
- height=550
86
- )
87
- img_buffer = io.BytesIO()
88
- pio.write_image(fig, img_buffer, format='png')
89
- img_buffer.seek(0)
90
- plot_b64 = base64.b64encode(img_buffer.getvalue()).decode('utf8')
91
- return plot_b64
92
-
93
- def base64_to_pil(b64_str):
94
- """
95
- Convert a base64 encoded string to a PIL Image.
96
- """
97
- img_data = base64.b64decode(b64_str)
98
- return Image.open(io.BytesIO(img_data))
99
 
100
  def process_text(input_text):
101
  """
102
- Process the user-input text:
103
- - If the text is empty or has no valid sentences/entities, set an error message.
104
- - Otherwise, produce HTML with highlighted entities and a bar chart image of tag counts.
105
  """
106
- error_message = ""
107
- html_output = ""
108
- plot_image = None
109
-
110
  if not input_text.strip():
111
- error_message = "No text provided. Please enter some text."
112
- return html_output, plot_image, error_message
113
-
114
  doc = nlp(input_text)
115
- sentences = list(doc.sents)
116
- if len(sentences) < 1:
117
- error_message = "Please enter at least one sentence."
118
- return html_output, plot_image, error_message
119
-
120
- if len(doc.ents) == 0:
121
- error_message = "No entities were detected. Please try again with a different input."
122
- return html_output, plot_image, error_message
123
-
124
- # Generate highlighted sentences and join them with HTML breaks.
125
- highlighted = get_highlighted_text(doc)
126
- html_output = "<br><br>".join(highlighted)
127
-
128
- # Get a counter for entity tags and create a bar chart.
129
- tag_counts = Counter([ent.label_ for ent in doc.ents])
130
- plot_b64 = create_tag_count_plot(tag_counts)
131
- plot_image = base64_to_pil(plot_b64)
132
-
133
- return html_output, plot_image, error_message
134
 
135
  # Build the Gradio interface.
136
  with gr.Blocks() as demo:
137
- gr.Markdown("# Named Entity Highlighter and Tag Counter")
138
- gr.Markdown(
139
- "Enter some text to visualize named entities. The app will highlight any detected entities in each sentence and show a bar chart of entity tag counts."
140
- )
141
-
142
- # Input textbox for user text.
143
  input_textbox = gr.Textbox(lines=10, label="Input Text", placeholder="Enter text here...")
144
-
145
- # Button to trigger the analysis.
146
- analyze_btn = gr.Button("Analyze Text")
147
-
148
- # Three outputs: highlighted text (as HTML), tag count plot image, and error message display.
149
- highlighted_output = gr.HTML(label="Highlighted Sentences")
150
- tag_plot_output = gr.Image(label="Tag Count Plot")
151
- error_output = gr.Textbox(label="Error Message", interactive=False)
152
-
153
- analyze_btn.click(
154
- fn=process_text,
155
- inputs=input_textbox,
156
- outputs=[highlighted_output, tag_plot_output, error_output]
157
- )
158
 
159
  if __name__ == "__main__":
160
- demo.launch()
 
18
  nlp = spacy.load(os.path.join(model_path, 'model-best'))
19
 
20
  # Make sure the pipeline can split into sentences
21
+ if 'parser' not in nlp.pipe_names and 'senter' not in nlp.pipe_names:
22
+ nlp.add_pipe('sentencizer')
23
+ # If the pipeline is missing a sentence splitter, add one
24
  if 'parser' not in nlp.pipe_names and 'senter' not in nlp.pipe_names:
25
  nlp.add_pipe('sentencizer')
26
 
27
  def get_highlighted_text(doc):
28
  """
29
+ Wrap detected ASCs (entities) in each sentence with a span tag that has a custom inline style.
30
+ Here, we assume all entities from the model correspond to ASCs.
 
31
  """
32
  highlighted_sentences = []
33
  for sent in doc.sents:
 
 
34
  sent_text = sent.text
35
+ # Get entities that are fully contained within the sentence.
36
+ ents_in_sent = [ent for ent in doc.ents if ent.start_char >= sent.start_char and ent.end_char <= sent.end_char]
 
37
  if ents_in_sent:
38
+ # Process entities in reverse order to avoid messing up character indices.
 
39
  ents_in_sent = sorted(ents_in_sent, key=lambda x: x.start_char, reverse=True)
40
  s = sent_text
41
  for ent in ents_in_sent:
42
+ # Compute positions relative to the sentence start
43
+ ent_start = ent.start_char - sent.start_char
44
+ ent_end = ent.end_char - sent.start_char
45
+ # Wrap the entity in a span with a custom style. Adjust color & style as needed.
46
  s = (
47
  s[:ent_start]
48
+ + f'<span style="background-color: #add8e6; font-weight: bold;" title="{ent.label_}">'
49
+ + s[ent_start:ent_end]
50
+ + '</span>'
51
  + s[ent_end:]
52
  )
53
  highlighted_sentences.append(s)
54
  else:
55
  highlighted_sentences.append(sent_text)
56
+ # Join sentences with HTML breaks so the output preserves sentence separations.
57
+ return "<br><br>".join(highlighted_sentences)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def process_text(input_text):
60
  """
61
+ Process the user input text to detect and tag ASCs.
62
+ Returns an HTML string with tagged entities.
 
63
  """
 
 
 
 
64
  if not input_text.strip():
65
+ return "No text provided. Please enter some text."
66
+
 
67
  doc = nlp(input_text)
68
+ # Check if there are sentences; if not, return a message.
69
+ if len(list(doc.sents)) == 0:
70
+ return "Please enter at least one sentence."
71
+ # If no entities (ASCs) are found, let the user know.
72
+ if not doc.ents:
73
+ return "No ASCs were detected."
74
+
75
+ # Get the HTML with highlighted ASCs.
76
+ return get_highlighted_text(doc)
 
 
 
 
 
 
 
 
 
 
77
 
78
  # Build the Gradio interface.
79
  with gr.Blocks() as demo:
80
+ gr.Markdown("# ASC Tagger")
81
+ gr.Markdown("Enter some text to have ASCs tagged (highlighted with a custom color scheme).")
82
+
 
 
 
83
  input_textbox = gr.Textbox(lines=10, label="Input Text", placeholder="Enter text here...")
84
+ output_html = gr.HTML(label="Tagged Text")
85
+ tag_btn = gr.Button("Tag ASCs")
86
+
87
+ tag_btn.click(fn=process_text, inputs=input_textbox, outputs=output_html)
 
 
 
 
 
 
 
 
 
 
88
 
89
  if __name__ == "__main__":
90
+ demo.launch()