Spaces:
Sleeping
Sleeping
| """ | |
| Author: Khanh Phan | |
| Date: 2024-12-04 | |
| """ | |
| import colorsys | |
| import json | |
| import re | |
| import gradio as gr | |
| import openai | |
| from transformers import pipeline | |
| from src.application.config import ( | |
| AZUREOPENAI_CLIENT, | |
| ENTITY_BRIGHTNESS, | |
| ENTITY_DARKEN_COLOR, | |
| ENTITY_LIGHTEN_COLOR, | |
| ENTITY_SATURATION, | |
| GPT_ENTITY_MODEL, | |
| ) | |
# Token-classification (NER) pipeline used by extract_entities(). No model
# is passed explicitly, so transformers selects its default model for the
# "ner" task. It is constructed once, when this module is imported.
ner_pipeline = pipeline(task="ner")
def extract_entities_gpt(
    original_text,
    compared_text,
    text_generation_model=GPT_ENTITY_MODEL,
) -> str:
    """
    Extracts entity pairs with significantly different meanings between
    two texts using a GPT model.

    Args:
        original_text (str): The original text.
        compared_text (str): The paraphrased or compared text.
        text_generation_model (str, optional): The GPT model to use for
            entity extraction. Defaults to GPT_ENTITY_MODEL.

    Returns:
        str: The model's raw reply, which the prompt asks to be a JSON-like
        list of ``[original_entity, compared_entity]`` pairs. An empty
        string is returned if the API call raises ``openai.OpenAIError``
        or if the reply carries no text content.
    """
    # Construct the prompt for the GPT model.
    # TODO: Move to config or prompt file
    prompt = f"""
Compare the ORIGINAL TEXT and the COMPARED TEXT.
Find entity pairs with significantly different meanings after paraphrasing.
Focus only on these significantly changed entities. These include:
* **Numerical changes:** e.g., "five" -> "ten," "10%" -> "50%"
* **Time changes:** e.g., "Monday" -> "Sunday," "10th" -> "21st"
* **Name changes:** e.g., "Tokyo" -> "New York," "Japan" -> "Japanese"
* **Opposite meanings:** e.g., "increase" -> "decrease," "good" -> "bad"
* **Semantically different words:** e.g., "car" -> "truck," "walk" -> "run"
Exclude entities where the meaning remains essentially the same,
even if the wording is different
(e.g., "big" changed to "large," "house" changed to "residence").
Also exclude purely stylistic changes that don't affect the core meaning.
Output the extracted entity pairs, one pair per line,
in the following JSON-like list format without wrapping characters:
[
["ORIGINAL_TEXT_entity_1", "COMPARED_TEXT_entity_1"],
["ORIGINAL_TEXT_entity_2", "COMPARED_TEXT_entity_2"]
]
If there are no entities that satisfy above condition, output empty list "[]".
---
# ORIGINAL TEXT:
{original_text}
---
# COMPARED TEXT:
{compared_text}
"""
    # Send the prompt to the GPT model and get the response.
    try:
        response = AZUREOPENAI_CLIENT.chat.completions.create(
            model=text_generation_model,
            messages=[{"role": "user", "content": prompt}],
        )
    except openai.OpenAIError as e:
        print(f"Error interacting with OpenAI API: {e}")
        return ""

    # The chat completion's ``message.content`` may be None (for example
    # when the reply is empty or filtered). Normalise it so callers can
    # always treat the result as a string.
    return response.choices[0].message.content or ""
def read_json(json_string: str) -> list[list[str]]:
    """
    Parses a JSON string and returns a list of unique entity pairs.

    Only well-formed pairs are kept: JSON arrays of exactly two strings.
    Anything else a model might emit is dropped, so downstream code can
    safely use ``pair[0]`` / ``pair[1]`` as strings. This includes objects,
    bare strings, single entities, numbers and nested structures. Pairs
    keep the order of their first occurrence.

    Args:
        json_string (str): The JSON string to parse.

    Returns:
        list[list[str]]: Unique ``[original, compared]`` pairs. The list is
        empty if parsing fails or the top-level value is not a JSON array.
    """
    try:
        entities = json.loads(json_string)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return []

    # A dict would otherwise iterate over its keys; null would raise.
    if not isinstance(entities, list):
        print(
            "Error decoding JSON: expected a list of entity pairs, "
            f"got {type(entities).__name__}",
        )
        return []

    unique_entities = []
    seen = set()  # Tuples of already-kept pairs, for O(1) duplicate checks.
    for pair in entities:
        is_valid_pair = (
            isinstance(pair, list)
            and len(pair) == 2
            and all(isinstance(item, str) for item in pair)
        )
        if not is_valid_pair:
            continue
        key = tuple(pair)
        if key not in seen:
            seen.add(key)
            unique_entities.append(pair)
    return unique_entities
def set_color_brightness(
    hex_color: str,
    brightness_factor: float = ENTITY_LIGHTEN_COLOR,
) -> str:
    """
    Scales the brightness (HSV "value") of a HEX color.

    A factor above 1 lightens the color and a factor below 1 darkens it.
    The resulting brightness is clamped to the valid range [0, 1].

    Args:
        hex_color (str): The HEX color code, "#RRGGBB" or shorthand "#RGB".
            The leading "#" is optional.
        brightness_factor (float, optional): Multiplier applied to the
            brightness. Defaults to ENTITY_LIGHTEN_COLOR.

    Returns:
        str: The adjusted color as a lowercase "#rrggbb" HEX code.

    Raises:
        ValueError: If ``hex_color`` is not a 3- or 6-digit HEX code.
    """
    hex_digits = hex_color.lstrip("#")
    if len(hex_digits) == 3:
        # Expand shorthand "abc" to "aabbcc".
        hex_digits = "".join(digit * 2 for digit in hex_digits)
    if len(hex_digits) != 6:
        raise ValueError(f"Invalid HEX color: {hex_color!r}")

    # Parse the red, green and blue components (int raises on non-hex).
    r, g, b = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))

    h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)
    # Clamp on both sides: a negative factor must not produce negative RGB.
    v = max(0.0, min(1.0, v * brightness_factor))

    # round() rather than int(): float error such as 127.99999 must not
    # knock a channel down by one.
    r, g, b = (round(c * 255) for c in colorsys.hsv_to_rgb(h, s, v))
    return f"#{r:02x}{g:02x}{b:02x}"
def generate_colors(index: int, total_colors: int = 20) -> str:
    """
    Generates a distinct, evenly spaced color for each index using HLS.

    Hues are spread uniformly around the color wheel. Lightness comes from
    ENTITY_BRIGHTNESS and saturation from ENTITY_SATURATION.

    Args:
        index (int): The index for which to generate a color. Indices at or
            beyond ``total_colors`` wrap around the color wheel, because
            colorsys reduces the hue modulo 1.
        total_colors (int, optional): The total number of colors to
            distribute evenly. Defaults to 20. Values below 1 are treated
            as 1.

    Returns:
        str: A HEX color code representing the generated color.
    """
    # Avoid ZeroDivisionError for an empty palette.
    total_colors = max(total_colors, 1)
    hue = index / total_colors

    # colorsys.hls_to_rgb takes (hue, LIGHTNESS, SATURATION), so the
    # lightness setting goes second and the saturation setting goes last.
    r, g, b = colorsys.hls_to_rgb(hue, ENTITY_BRIGHTNESS, ENTITY_SATURATION)

    # Scale [0, 1] floats to [0, 255] integers; round() avoids the
    # downward bias of int() truncation.
    r, g, b = (round(c * 255) for c in (r, g, b))
    return f"#{r:02x}{g:02x}{b:02x}"
def assign_colors_to_entities(entities: list) -> list[dict]:
    """
    Gives every entity pair its own color from an evenly spaced palette.

    The palette has exactly ``len(entities)`` colors, and the pair at
    position ``i`` receives color ``i``.

    Args:
        entities (list): Entity pairs, each a two-item sequence
            ``[original_entity, compared_entity]``, e.g.
            [["entity1_original", "entity1_compared"]].

    Returns:
        list[dict]: One dict per pair, in input order, with the keys
            - "color": HEX color shared by both sides of the pair.
            - "input": the original-text entity (first item).
            - "source": the compared-text entity (second item).
    """
    palette_size = len(entities)
    return [
        {
            "color": generate_colors(position, palette_size),
            "input": pair[0],
            "source": pair[1],
        }
        for position, pair in enumerate(entities)
    ]
def highlight_entities(text1: str, text2: str) -> list[dict]:
    """
    Finds entities whose meaning differs between two texts and gives each
    differing pair a unique color.

    Args:
        text1 (str): Input (original) text.
        text2 (str): Source (compared) text.

    Returns:
        list: One dict per differing entity pair, with the keys "color",
        "input" (entity in ``text1``) and "source" (entity in ``text2``).
        Returns None if either text is None, if the model reply is empty or
        unusable, or if no significant entities are found.
    """
    if text1 is None or text2 is None:
        return None

    # The reply may be empty (or None) when the API call fails or returns
    # no content; treat that the same as "no entities".
    entities_text = extract_entities_gpt(text1, text2) or ""

    # Models often wrap JSON in a Markdown code fence despite instructions.
    entities_text = entities_text.replace("```json", "").replace("```", "")

    entities = read_json(entities_text)
    if not entities:
        return None

    return assign_colors_to_entities(entities)
def _entity_pattern(entity_text: str) -> str:
    """Builds a regex that matches ``entity_text`` as a standalone token."""
    # (?<!\w) and (?!\w) behave like \b next to word characters. Unlike \b,
    # they still match entities that begin or end with punctuation, such as
    # "10%" or "$5" (\b requires a word character on one side).
    return r"(?<!\w)" + re.escape(entity_text) + r"(?!\w)"


def _find_entity_spans(
    text: str,
    entities_with_colors: list[dict],
    key: str,
) -> list[tuple[int, int, int]]:
    """Returns sorted, non-overlapping (start, end, entity_index) matches."""
    spans = []
    for entity_index, entity in enumerate(entities_with_colors):
        entity_text = entity[key]
        if not isinstance(entity_text, str) or not entity_text.strip():
            # An empty or blank pattern would match almost everywhere.
            continue
        for match in re.finditer(_entity_pattern(entity_text), text):
            start, end = match.span()
            # Entities earlier in the list win when matches overlap.
            if all(end <= s or start >= e for s, e, _ in spans):
                spans.append((start, end, entity_index))
    spans.sort()
    return spans


def _format_highlight(
    entity_text: str,
    label: int,
    entity_color: str,
    label_color: str,
) -> str:
    """Wraps ``entity_text`` in a colored span preceded by a numbered badge."""
    index_label = (
        f'<span_style="background-color:{label_color};color:white;'
        f"padding:1px_4px;border-radius:4px;font-size:12px;"
        f'font-weight:bold;display:inline-block;margin-right:4px;">{label}</span>'  # noqa: E501
    )
    return (
        f'<span_style="background-color:{entity_color};color:black;'
        f'border-radius:3px;font-size:14px;display:inline-block;">'
        f"{index_label}{entity_text}</span>"
    )


def apply_highlight(
    text: str,
    entities_with_colors: list[dict],
    key: str = "input",
    count: int = 0,
) -> tuple[str, list[int]]:
    """
    Wraps every occurrence of the given entities in colored, numbered spans.

    All entities are matched against the ORIGINAL text in a single pass, so
    an entity can never match inside HTML that was inserted for another
    entity (e.g. "color" or "4" inside a style attribute). When matches
    overlap, the entity that comes first in ``entities_with_colors`` wins.

    The emitted markup uses "<span_style=" and "padding:1px_4px", with no
    spaces inside the tags. This presumably keeps each highlight a single
    whitespace-separated word for get_index_list(). Callers likely need to
    restore the spaces before rendering — confirm.

    Args:
        text (str): The text to highlight.
        entities_with_colors (list): Dicts with a "color" HEX code and the
            entity strings (see ``key``), as built by
            assign_colors_to_entities().
        key (str, optional): Which entity string to look for, "input" or
            "source". Defaults to "input".
        count (int, optional): Offset added to the 1-based index labels.

    Returns:
        tuple: ``(highlighted_text, highlight_idx_list)``, where the list
        holds the positions of highlighted words as computed by
        get_index_list(). Returns ``(text, [])`` when there is nothing to
        highlight with, or when ``text`` is empty.
    """
    if not entities_with_colors:
        return text, []

    spans = _find_entity_spans(text, entities_with_colors, key)

    # Lightened background color and darker label color, computed once per
    # entity that actually occurs in the text.
    styles = {}
    for _, _, entity_index in spans:
        if entity_index in styles:
            continue
        entity_color = set_color_brightness(
            entities_with_colors[entity_index]["color"],
            brightness_factor=ENTITY_LIGHTEN_COLOR,
        )
        label_color = set_color_brightness(
            entity_color,
            brightness_factor=ENTITY_DARKEN_COLOR,
        )
        styles[entity_index] = (entity_color, label_color)

    pieces = []
    prev_end = 0
    for start, end, entity_index in spans:
        entity_color, label_color = styles[entity_index]
        pieces.append(text[prev_end:start])  # Plain text before the match.
        pieces.append(
            _format_highlight(
                text[start:end],
                entity_index + 1 + count,
                entity_color,
                label_color,
            ),
        )
        prev_end = end
    pieces.append(text[prev_end:])  # Plain text after the last match.
    highlighted_text = "".join(pieces)

    if highlighted_text == "":
        return text, []

    return highlighted_text, get_index_list(highlighted_text)
def get_index_list(highlighted_text: str) -> list[int]:
    """
    Lists the positions of words that fall inside highlighted spans.

    The text is split on whitespace. Every word lying even partly inside a
    top-level ``<span_style ...>...</span>`` highlight is reported by its
    position in that word list. Tag nesting is tracked, with these effects:

    - The numbered label span nested inside each highlight does not end
      the highlight early.
    - A multi-word entity is covered from its first to its last word.
    - Punctuation glued to a highlight (``...</span>'s``,
      ``"<span_style...``) does not hide it.

    Args:
        highlighted_text (str): Text containing highlighted words wrapped
            in ``<span_style ...>`` / ``</span>`` tags.

    Returns:
        list: Sorted, unique word indices belonging to a highlight, or an
        empty list if none are found. A highlight that is never closed is
        ignored.
    """
    highlighted_index = []
    depth = 0  # Current nesting level of <span_style ...> tags.
    start_index = None  # Word where the current top-level highlight began.

    for word_index, word in enumerate(highlighted_text.split()):
        # Walk the tags in order, since a word can both close one
        # highlight and open another.
        for tag in re.finditer(r"<span_style|</span>", word):
            if tag.group() == "<span_style":
                if depth == 0:
                    start_index = word_index
                depth += 1
            elif depth > 0:  # Ignore stray closing tags.
                depth -= 1
                if depth == 0:
                    # Don't repeat a word shared by adjacent highlights.
                    first = start_index
                    if highlighted_index and highlighted_index[-1] >= first:
                        first = highlighted_index[-1] + 1
                    highlighted_index.extend(range(first, word_index + 1))
                    start_index = None

    return highlighted_index
def extract_entities(text: str):
    """
    Runs the NER pipeline over ``text`` and returns the distinct entities.

    Args:
        text (str): The input text to extract entities from.

    Returns:
        list: Unique entity strings in order of first appearance, after
        word pieces have been merged by combine_subwords().
    """
    raw_tokens = extract_words(ner_pipeline(text))
    merged_words = combine_subwords(raw_tokens)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(merged_words))
def extract_words(entities: list[dict]) -> list[str]:
    """
    Collects the "word" field of every entity, preserving order.

    Args:
        entities (list): Entity dicts, each required to have a "word" key.

    Returns:
        list[str]: The words, one per entity, in input order.
    """
    return [entity["word"] for entity in entities]
def combine_subwords(word_list):
    """
    Merges word-piece tokens back into whole words.

    - A token starting with "##" is a continuation piece. It is appended,
      without the "##", to the previous word. If there is no previous
      word, it becomes a word on its own.
    - A "-" token joins the previous word and the following token into one
      hyphenated word. Chains such as "a - b - c" become "a-b-c", and a
      word rebuilt from "##" pieces can still take a hyphen.

    Args:
        word_list (list): Tokens, where subwords are prefixed with "##".

    Returns:
        list: A new list with subwords combined with their preceding words
        and hyphenated words combined.
    """
    result = []
    i = 0
    while i < len(word_list):
        word = word_list[i]
        if word.startswith("##"):
            if result:
                result[-1] += word[2:]
            else:
                # Orphan continuation piece: keep it instead of crashing.
                result.append(word[2:])
        elif word == "-" and result and i + 1 < len(word_list):
            # Attach the hyphen and the following token to the last word.
            result[-1] += word + word_list[i + 1]
            i += 1  # The token after the hyphen has been consumed.
        else:
            result.append(word)
        i += 1
    return result
# Sample inputs pre-filled into the two textboxes of the Gradio demo in the
# __main__ section. compared_text paraphrases original_text and changes
# several entities, giving highlight_entities() concrete differences to
# detect:
# - "UK" becomes "Japan".
# - "Sir Keir Starmer" becomes "A leading Japanese figure".
# - "the US President-elect Donald Trump" becomes "the next US President".
original_text = """
Title: UK pledges support for Ukraine with 100-year pact
Content: Sir Keir Starmer has pledged to put Ukraine in the "strongest
possible position" on a trip to Kyiv where he signed a "landmark"
100-year pact with the war-stricken country. The prime minister's
visit on Thursday was at one point marked by loud blasts and air
raid sirens after a reported Russian drone attack was intercepted
by Ukraine's defence systems. Acknowledging the "hello" from Russia,
Volodymyr Zelensky said Ukraine would send its own "hello back".
An estimated one million people have been killed or wounded in the
war so far. As the invasion reaches the end of its third year, Ukraine
is losing territory in the east. Zelensky praised the UK's commitment
on Thursday, amid wider concerns that the US President-elect Donald
Trump, who is set to take office on Monday, could potentially reduce aid.
"""
compared_text = """
Title: Japan pledges support for Ukraine with 100-year pact
Content: A leading Japanese figure has pledged to put Ukraine
in the "strongest possible position" on a trip to Kyiv where
they signed a "landmark" 100-year pact with the war-stricken country.
The visit on Thursday was at one point marked by loud blasts and air
raid sirens after a reported Russian drone attack was intercepted by
Ukraine's defence systems. Acknowledging the "hello" from Russia,
Volodymyr Zelensky said Ukraine would send its own "hello back".
An estimated one million people have been killed or wounded in the
war so far. As the invasion reaches the end of its third year, Ukraine
is losing territory in the east. Zelensky praised Japan's commitment
on Thursday, amid wider concerns that the next US President, who is
set to take office on Monday, could potentially reduce aid.
"""
if __name__ == "__main__":

    def _to_display_html(highlighted):
        """
        Makes apply_highlight() output renderable in a gr.HTML component.

        apply_highlight() emits "<span_style=" and "padding:1px_4px" with no
        spaces, presumably so each highlight stays one whitespace-separated
        word for get_index_list(). A browser does not treat "<span_style=" as
        a styled <span>, so the spaced forms are restored here. Newlines are
        turned into <br> so the line layout of the input survives.
        """
        return (
            (highlighted or "")
            .replace("<span_style", "<span style")
            .replace("padding:1px_4px", "padding:1px 4px")
            .replace("\n", "<br>")
        )

    def _highlight_both(text1, text2):
        """Returns the two texts as HTML with their differing entities highlighted."""
        # highlight_entities() returns a list of color/entity dicts (or
        # None), not display HTML. Render each text separately with the
        # matching side of every entity pair.
        entities = highlight_entities(text1, text2)
        html1, _ = apply_highlight(text1, entities, key="input")
        html2, _ = apply_highlight(text2, entities, key="source")
        return _to_display_html(html1), _to_display_html(html2)

    with gr.Blocks() as demo:
        gr.Markdown("### Highlight Matching Parts Between Two Texts")
        text1_input = gr.Textbox(
            label="Text 1",
            lines=5,
            value=original_text,
        )
        text2_input = gr.Textbox(
            label="Text 2",
            lines=5,
            value=compared_text,
        )
        submit_button = gr.Button("Highlight Matches")
        output1 = gr.HTML("<br>" * 10)
        output2 = gr.HTML("<br>" * 10)
        # Two outputs, so the handler must return exactly two HTML strings.
        submit_button.click(
            fn=_highlight_both,
            inputs=[text1_input, text2_input],
            outputs=[output1, output2],
        )
    # Launch the Gradio app
    demo.launch()