import openai
import json
from pydantic import BaseModel, Field
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import requests
import spaces  # Hugging Face Spaces ZeroGPU helper; required by the Space runtime, unused below
class PromptTuple(BaseModel):
    class Tuple(BaseModel):
        type: str = Field(
            description="The type of the tuple. One of entity, attribute, relation",
            example="attribute",
        )
        type_detail: str = Field(
            description="""The detail of the type. For example:
            - Entity: whole (entire entity, e.g., chair), part (part of entity, e.g., back of chair).
            - Attribute: color (e.g., red book), type (e.g., aviator goggles), material (e.g., wooden chair), count (e.g., 5 geese), texture (e.g., rough surface), text rendering (e.g., letters "Macaroni"), shape (e.g., triangle block), size (e.g., large fence).
            - Relation: spatial (e.g., A next to B); action (e.g., A kicks B).""",
            example="color",
        )
        semantics: list = Field(
            description="List of strings that explain the existence of type and type_detail in the tuple",
            example=["motorcycle", "blue"],
        )

    tuples: list[Tuple] = Field(
        description="List of tuples. Maximum 8 tuples.",
        example=[
            {
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["motorcycle", "blue"],
            }
        ],
    )
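
# Illustrative only: constructing the schema by hand for the caption
# "A blue motorcycle". The values below are assumed sample data, not
# guaranteed model output; pydantic coerces the dict into a Tuple.
_example = PromptTuple(
    tuples=[
        {"type": "attribute", "type_detail": "color", "semantics": ["motorcycle", "blue"]}
    ]
)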
class DSGPromptProcessor:
    def __init__(self, model_name="gpt-4o-mini"):
        self.client = openai.OpenAI()
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.binary_vqa = AutoModelForCausalLM.from_pretrained(
            "toilaluan/Florence-2-base-Yes-No-VQA", trust_remote_code=True
        ).to(self.device, torch.float16)
        self.binary_vqa_processor = AutoProcessor.from_pretrained(
            "toilaluan/Florence-2-base-Yes-No-VQA", trust_remote_code=True
        )
    def generate_tuples(self, input_text: str) -> tuple[PromptTuple, int]:
        system_message = """
Given an image caption, extract the relevant entities, attributes, and relations present in the caption, and structure them into JSON format according to the following schema:
Each tuple contains the following information:
- Type: The category of the tuple. One of "entity", "attribute", or "relation".
- Type Detail: Additional details based on the selected type:
    - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
    - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape", or "size".
    - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
- Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
Example Input: "A blue motorcycle parked next to a red car."
Example output:
{
    "tuples": [
        {
            "type": "entity",
            "type_detail": "whole",
            "semantics": ["motorcycle"]
        },
        {
            "type": "attribute",
            "type_detail": "color",
            "semantics": ["motorcycle", "blue"]
        },
        {
            "type": "entity",
            "type_detail": "whole",
            "semantics": ["car"]
        },
        {
            "type": "attribute",
            "type_detail": "color",
            "semantics": ["car", "red"]
        },
        {
            "type": "relation",
            "type_detail": "spatial",
            "semantics": ["motorcycle", "next to", "car"]
        }
    ]
}
The final JSON should contain a list of tuples, each describing a unique entity, attribute, or relation from the image caption. Each JSON should contain a maximum of 8 tuples.
"""
        messages = [
            {
                "role": "system",
                "content": system_message,
            },
            {
                "role": "user",
                "content": input_text,
            },
        ]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            response_format={"type": "json_object"},
            max_tokens=512,
        )
        output = json.loads(response.choices[0].message.content)
        return PromptTuple(**output), response.usage.total_tokens
    def generate_dependencies(self, tuples: PromptTuple) -> tuple[dict, int]:
        DEPENDENCY_PROMPT = """
Given the following tuples extracted from an image caption, determine the dependencies between the entities, attributes, and relations, and return them in JSON format.
Each tuple contains the following information:
- Id: A unique identifier for the tuple.
- Type: The category of the tuple. One of "entity", "attribute", or "relation".
- Type Detail: Additional details based on the selected type:
    - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
    - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "texture", "text rendering", "shape", or "size".
    - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
- Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
Output is a dictionary where the key is the id of the tuple and the value is a list of ids that the tuple depends on.
Example input:
[
    {
        "id": 1,
        "type": "entity",
        "type_detail": "whole",
        "semantics": ["motorcycle"]
    },
    {
        "id": 2,
        "type": "attribute",
        "type_detail": "color",
        "semantics": ["motorcycle", "blue"]
    },
    {
        "id": 3,
        "type": "entity",
        "type_detail": "whole",
        "semantics": ["car"]
    },
    {
        "id": 4,
        "type": "attribute",
        "type_detail": "color",
        "semantics": ["car", "red"]
    },
    {
        "id": 5,
        "type": "relation",
        "type_detail": "spatial",
        "semantics": ["motorcycle", "next to", "car"]
    }
]
Example output:
{
    "1": [],
    "2": [1],
    "3": [],
    "4": [3],
    "5": [1, 3]
}
"""
        input_obj = [{"id": i, **t.dict()} for i, t in enumerate(tuples.tuples)]
        messages = [
            {
                "role": "system",
                "content": DEPENDENCY_PROMPT,
            },
            {
                "role": "user",
                "content": json.dumps(input_obj),
            },
        ]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            response_format={"type": "json_object"},
        )
        return (
            json.loads(response.choices[0].message.content),
            response.usage.total_tokens,
        )
    def generate_questions(
        self, prompt: str, tuples: list[dict], dependencies: dict
    ) -> tuple[dict, int]:
        """Generate validation questions from the tuples and dependencies.

        Args:
            prompt (str): a prompt describing the image
            tuples (list[dict]): each tuple is a unit of information extracted from the prompt
            dependencies (dict): the dependencies between tuples
        """
        system_message = """
Task: Given a prompt that describes an image and a list of tuples extracted from the prompt, generate one natural-language question per tuple.
Each tuple contains the following information:
- Id: A unique identifier for the tuple.
- Type: The category of the tuple. One of "entity", "attribute", or "relation".
- Type Detail: Additional details based on the selected type:
    - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
    - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape", or "size".
    - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
- Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
Output is a list of questions, one per tuple. The number of questions must be the same as the number of tuples.
Example input:
Prompt: "A traffic light and a signpost at a crossroads intersection near a waterway"
Tuples:
[
    {
        "id": 1,
        "type": "entity",
        "type_detail": "whole",
        "semantics": ["traffic light"]
    },
    {
        "id": 2,
        "type": "entity",
        "type_detail": "whole",
        "semantics": ["signpost"]
    },
    {
        "id": 3,
        "type": "relation",
        "type_detail": "spatial",
        "semantics": ["traffic light", "at", "crossroads intersection"]
    },
    {
        "id": 4,
        "type": "relation",
        "type_detail": "spatial",
        "semantics": ["crossroads intersection", "near", "waterway"]
    }
]
Dependencies:
{
    "1": [],
    "2": [],
    "3": [1, 2],
    "4": [3]
}
Example output is a JSON object. Each question asks about the existence of the tuple in the prompt, and the expected answer is always yes.
{
    "1": "Is there a traffic light?",
    "2": "Is there a signpost?",
    "3": "Is the traffic light at a crossroads intersection?",
    "4": "Is the crossroads intersection near a waterway?"
}
"""
        user_str = f"""
Prompt: {prompt}
Tuples: {tuples}
Dependencies: {dependencies}
"""
        messages = [
            {
                "role": "system",
                "content": system_message,
            },
            {
                "role": "user",
                "content": user_str,
            },
        ]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            response_format={"type": "json_object"},
        )
        return (
            json.loads(response.choices[0].message.content),
            response.usage.total_tokens,
        )
    def find_layers(self, dep_dict):
        # Topologically sort tuple ids into layers: an id enters a layer only
        # once all of its dependencies appear in earlier layers.
        layers = []
        placed = set()
        remaining_keys = set(dep_dict.keys())
        while remaining_keys:
            current_layer = [
                key
                for key in remaining_keys
                if all(str(dep) in placed for dep in dep_dict[key])
            ]
            # If no new layer is formed (e.g., a dependency cycle), break to
            # avoid an infinite loop
            if not current_layer:
                break
            layers.append(current_layer)
            placed.update(current_layer)
            remaining_keys -= set(current_layer)
            # Cap the graph at 3 layers; deeper questions are dropped
            if len(layers) == 3:
                break
        # Flatten the layers into a single dependency-respecting order of ids
        return [item for layer in layers for item in layer]
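    # Worked example (assumed data, in the generate_dependencies output format):
    # given {"1": [], "2": ["1"], "3": [], "4": ["2", "3"]}, the layers are
    # ["1", "3"] (no dependencies), then ["2"], then ["4"], so find_layers
    # returns ["1", "3", "2", "4"]. Order within a layer follows set iteration
    # and is therefore not deterministic.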
    def _create_graph_questions(
        self, questions: dict, dependencies: dict
    ) -> tuple[list, list]:
        # Order questions so that each question comes after every question it
        # depends on, and return the matching id order for score lookups.
        ordered_ids = self.find_layers(dependencies)
        sorted_questions = [questions[i] for i in ordered_ids]
        return sorted_questions, ordered_ids
    def get_reward(
        self,
        questions: dict,
        dependencies: dict,
        images: list,
        mode="hybrid",
    ):
        """Get rewards for the generated questions using the structured question graph.

        Args:
            questions (dict): questions generated from the tuples, keyed by tuple id
            dependencies (dict): the dependencies between tuples
            images (list[Image.Image]): a list of PIL images
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.binary_vqa.to(self.device)
        sorted_questions, ordered_ids = self._create_graph_questions(
            questions, dependencies
        )
        # Map each tuple id to its position in the sorted question order so
        # that dependency scores can be looked up by position.
        id_to_position = {qid: pos for pos, qid in enumerate(ordered_ids)}
        scores = {i: [0] * len(sorted_questions) for i in range(len(images))}

        def get_reward_for_a_question(
            question: str,
            question_dependencies: list[int],
            image: Image.Image,
            prev_scores: list[float],
        ) -> float:
            # Skip a question if any question it depends on was answered "No".
            if any(prev_scores[i] <= 0.5 for i in question_dependencies):
                print(
                    f"Skipping question: {question}. A question it depends on was answered as No."
                )
                return 0
            if not isinstance(image, Image.Image):
                raise ValueError("Invalid image type")
            inputs = self.binary_vqa_processor(
                text=question, images=image, return_tensors="pt"
            ).to(self.device, torch.float16)
            decoder_input_ids = torch.LongTensor(
                [
                    [
                        self.binary_vqa.language_model.config.pad_token_id,
                        self.binary_vqa.language_model.config.decoder_start_token_id,
                    ]
                ]
            ).to(self.device)
            outputs = self.binary_vqa(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                decoder_input_ids=decoder_input_ids,
            )
            # The fine-tuned head emits a "Yes" logit at the last position; its
            # sigmoid is the probability that the answer is yes.
            logits = outputs.logits[:, -1]
            score = logits[0].sigmoid().item()
            print(f"The answer Yes has probability {score}")
            return score

        pbar = tqdm(
            total=len(sorted_questions) * len(images),
            desc=f"Calculating reward over {len(images)} images and {len(sorted_questions)} questions",
        )
        for i, question in enumerate(sorted_questions):
            # Translate this question's dependency ids into sorted positions;
            # ids dropped by the 3-layer cap in find_layers are skipped.
            deps = [
                id_to_position[str(d)]
                for d in dependencies[ordered_ids[i]]
                if str(d) in id_to_position
            ]
            for j, image in enumerate(images):
                scores[j][i] = get_reward_for_a_question(
                    question, deps, image, scores[j]
                )
                pbar.update(1)
        return scores, sorted_questions
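
# A minimal sketch (not part of the original class) of one way to collapse the
# per-question yes-probabilities from get_reward into a single score per image.
# Plain averaging is an assumption here, not the author's aggregation rule.
def aggregate_scores(scores: dict) -> dict:
    """Average each image's per-question scores into one reward value."""
    return {
        image_idx: sum(question_scores) / len(question_scores)
        for image_idx, question_scores in scores.items()
        if question_scores
    }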
if __name__ == "__main__":
    # Assumes OPENAI_API_KEY is set; the default model targets the OpenAI API.
    processor = DSGPromptProcessor(model_name="gpt-4o-mini")
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
    image = Image.open(requests.get(url, stream=True).raw)
    input_text = "ghibli style image of a cat"
    tuples, tokens = processor.generate_tuples(input_text)
    print(tuples)
    dependencies, tokens = processor.generate_dependencies(tuples)
    print(dependencies)
    questions, tokens = processor.generate_questions(
        input_text, tuples.tuples, dependencies
    )
    print(questions)
    rewards, sorted_questions = processor.get_reward(questions, dependencies, [image])
    print(rewards)