Spaces:
Running
Running
| import random | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch # Need this for torch.no_grad() | |
| from datasets import load_dataset | |
| from qwen_vl_utils import process_vision_info | |
| from transformers import ( | |
| AutoProcessor, | |
| Qwen2_5_VLForConditionalGeneration, | |
| TextIteratorStreamer, | |
| ) | |
| from trl import ModelConfig | |
def get_eval_dataset():
    """Return the held-out evaluation split of the message-decoding dataset.

    The "train" split is loaded, shuffled with seed 42, and then split
    90/10 with seed 42; the 10% "test" portion is returned.

    Returns:
        The "test" part produced by ``train_test_split``.
    """
    dataset = load_dataset("sunildkumar/message-decoding-words-and-sequences")
    shuffled = dataset["train"].shuffle(seed=42)
    # The seed matches the one used in the training script, so this yields
    # the same train/test partition that training used.
    return shuffled.train_test_split(test_size=0.1, seed=42)["test"]
def load_model_and_tokenizer():
    """Load the message-decoding checkpoint together with its processor.

    Returns:
        tuple: ``(model, processor)``. The model has been switched to eval
        mode; the processor is created with left-side padding.
    """
    config = ModelConfig(
        model_name_or_path="Groundlight/message-decoding-r1",
        torch_dtype="bfloat16",
        use_peft=False,
    )
    checkpoint = config.model_name_or_path

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=config.torch_dtype,
        # NOTE(review): use_cache=False looks like a training-time setting;
        # confirm it is really wanted for generation.
        use_cache=False,
        # "auto" leaves device placement to the loader.
        device_map="auto",
    )
    # Inference only: switch off training-mode behaviour.
    model.eval()

    processor = AutoProcessor.from_pretrained(checkpoint, padding_side="left")
    return model, processor
# Resource loading lives in a function so it runs only when explicitly called.
def load_resources():
    """Populate the module-level ``eval_dataset``, ``model`` and ``processor``.

    The evaluation dataset is loaded first, then the model and processor.
    """
    global eval_dataset, model, processor
    eval_dataset = get_eval_dataset()
    loaded = load_model_and_tokenizer()
    model = loaded[0]
    processor = loaded[1]
def show_random_example():
    """Pick one evaluation example uniformly at random.

    Reads the module-level ``eval_dataset``.

    Returns:
        tuple: ``(image, mapping, image)`` -- the image for display, the
        mapping for state, and the image again for state.
    """
    sample = eval_dataset[random.randrange(len(eval_dataset))]
    picture = sample["image"]
    return picture, sample["mapping"], picture
def prepare_model_input(image, mapping, processor, submitted_word):
    """Build the processed model inputs for decoding ``submitted_word``.

    The word is lower-cased, encoded by inverting ``mapping``, embedded in
    the instruction prompt, and passed with the decoder image through the
    chat template and the processor.

    Args:
        image: The decoder image to use.
        mapping (dict): The mapping data from the dataset; it is inverted
            here to turn decoded characters back into coded ones.
        processor: The model's processor/tokenizer.
        submitted_word (str): The word submitted by the user.

    Returns:
        The batch produced by calling ``processor`` with the templated
        text and the image input (``return_tensors="pt"``).
    """
    plaintext = submitted_word.lower()
    print(f"Decoded message: {plaintext}")

    # Invert the decoder (coded -> decoded) to obtain an encoder.
    encoder = {decoded: coded for coded, decoded in mapping.items()}
    print(f"Encoder: {encoder}")

    # Characters with no encoder entry are passed through unchanged.
    coded_chars = [encoder.get(ch, ch) for ch in plaintext]
    print(f"Coded message: {coded_chars}")

    # Space-separate the symbols to prevent tokenization issues.
    coded_message = " ".join(coded_chars)

    # NOTE(review): there is no space between the first two sentences below.
    # Presumably this matches the prompt used in training -- confirm before
    # "fixing" it.
    task = (
        "Use the decoder in the image to decode a coded message."
        "The decoded message will be one or more words. Underscore characters "
        '("_") in the coded message should be mapped to a space (" ") when decoding.'
    )
    ending = (
        "Show your work in <think> </think> tags and return the answer in <answer> </answer> tags. "
        "While thinking, you must include a section with the decoded characters using <chars></chars> tags. "
        "The <chars> section should include the decoded characters in the order they are decoded. It should include the "
        "underscore character wherever there is a space in the decoded message. For example, if the coded message is "
        "a b c _ d e f, the chars section might be <chars> c a t _ d o g </chars>. You can think about the problem for "
        "as long as you'd like. While thinking, you should robustly verify your solution. Once you are done thinking, "
        f"provide your answer in the <answer> section, e.g. <answer> cat dog </answer>. The coded message is: {coded_message}."
    )
    prompt = f"{task} {ending}"
    print(f"Instruction: {prompt}")

    system_turn = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a helpful assistant. You first think about the reasoning process in the mind and then provide the user with the answer.",
            }
        ],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
    # The assistant turn is pre-filled; continue_final_message=True below
    # makes the model continue from it.
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me solve this step by step.\n<think>"}
        ],
    }
    conversation = [system_turn, user_turn, assistant_turn]

    chat_text = processor.apply_chat_template(
        conversation, continue_final_message=True, tokenize=False
    )
    vision_images, _ = process_vision_info(conversation)

    return processor(
        text=chat_text,
        images=[vision_images],
        padding=True,
        return_tensors="pt",
    )
def encode_word(word, mapping):
    """Encode ``word`` with the inverse of ``mapping``.

    The word is lower-cased and each character is replaced by the key that
    ``mapping`` maps to it. Characters that are not a value of ``mapping``
    are kept as they are. The resulting symbols are joined with single
    spaces.

    Returns:
        str: The space-separated coded message, or ``""`` when ``word`` or
        ``mapping`` is empty/falsy.
    """
    if word and mapping:
        # Invert the decoder (coded -> decoded) to obtain an encoder.
        encoder = {decoded: coded for coded, decoded in mapping.items()}
        return " ".join(encoder.get(ch, ch) for ch in word.lower())
    return ""
def _reject_submission():
    """Return the UI updates used whenever a submission is refused."""
    return (
        gr.update(),  # word input: unchanged
        gr.update(),  # submit button: unchanged
        gr.update(interactive=False),  # run button: disabled but kept visible
        gr.update(visible=False),  # encoded word display: hidden
    )


def validate_and_submit(word, mapping):
    """Validate the user's message and, if it is acceptable, encode it.

    Args:
        word (str): The message typed by the user.
        mapping (dict): The current decoder mapping; falsy if no decoder
            has been generated yet.

    Returns:
        tuple: Four ``gr.update`` objects, in order, for the word input,
        the submit button, the run button, and the encoded word display.
        On rejection a ``gr.Warning`` is issued, the run button is disabled
        and the encoded word display is hidden.
    """
    letters = (word or "").replace(" ", "")
    # str.isalpha() alone is true for any Unicode letter (e.g. "é"); the
    # isascii() check restricts the input to English letters as the warning
    # text promises. An empty string fails isalpha() and is rejected too.
    if not (letters.isascii() and letters.isalpha()):
        gr.Warning(
            "Invalid input! Please enter only English letters and spaces. No numbers or punctuation allowed."
        )
        return _reject_submission()

    if not mapping:
        gr.Warning("Please generate a decoder first")
        return _reject_submission()

    word = word.lower()
    encoded_word = encode_word(word, mapping)

    # Only enable the run button if there is a non-blank encoded word.
    if not encoded_word.strip():
        gr.Warning(
            "Invalid input! The word contains characters that cannot be encoded with the current decoder."
        )
        return _reject_submission()

    # Lock the input and the submit button, enable the run button, and show
    # the encoded message.
    return (
        gr.update(value=word, interactive=False, label="Submitted Word"),
        gr.update(interactive=False),
        gr.update(interactive=True),
        gr.update(value=f"Encoded message: {encoded_word}", visible=True),
    )
def prepare_for_inference():
    """Return the UI updates applied just before streaming starts.

    Returns:
        tuple: Three ``gr.update`` objects: clear and show the output box,
        disable the run button, and show the loading indicator.
    """
    output_update = gr.update(value="", visible=True)
    run_button_update = gr.update(interactive=False)
    loading_update = gr.update(visible=True)
    return output_update, run_button_update, loading_update
def run_inference(word, image, mapping):
    """Run the model on the submitted word and stream its output.

    Generation runs in a background thread feeding a
    ``TextIteratorStreamer``; this generator yields the accumulated text
    each time a new piece arrives.

    Args:
        word (str): The submitted message.
        image: The decoder image held in state.
        mapping (dict): The decoder mapping held in state.

    Yields:
        str: All text streamed so far.

    Raises:
        gr.Error: If the word, image or mapping is missing, or if
            generation fails in the background thread.
    """
    # ``image is None`` rather than ``not image``: truthiness is not a
    # reliable "missing" test for image objects (array-like images raise on
    # bool()).
    if not word or image is None or not mapping:
        raise gr.Error("Please submit a word and load a decoder first")

    model_inputs = prepare_model_input(image, mapping, processor, word)
    # The model is loaded with device_map="auto", so follow its actual
    # placement instead of hard-coding "cuda".
    target_device = model.device
    model_inputs = {k: v.to(target_device) for k, v in model_inputs.items()}

    # Extra keyword arguments are forwarded to decode(); a separate
    # ``decode_kwargs=`` argument is not a streamer parameter and would be
    # forwarded as a stray keyword, so it is not passed.
    streamer = TextIteratorStreamer(
        tokenizer=processor,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    failures = []

    def generate_with_no_grad():
        try:
            with torch.no_grad():
                model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised to the UI below
            failures.append(exc)
            # Without this the consumer loop below would block forever
            # waiting for a stop signal that never comes.
            streamer.end()

    thread = Thread(target=generate_with_no_grad)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()

    if failures:
        raise gr.Error(f"Generation failed: {failures[0]}") from failures[0]
    return generated_text
# Text shown in the read-only "Instructions" box.
instructions = """
Welcome! This demos Groundlight's visual reasoning model trained to decode cryptograms. To use the model:
1. Generate a decoder image. This will be provided to the model to decode your message.
2. Enter your message in the text box below. Your message should only contain English letters and spaces.
Some examples:
• hello world
• i love reinforcement learning
• groundlight makes computer vision easy
3. Encode your message. Just click the "Encode Message" button, and we'll handle encoding for you.
4. Run the model. You will see the model's reasoning process and the decoded message in <answer></answer> tags.
"""


def _reset_message_controls():
    """Return updates that put the message widgets back in their initial state.

    Order: word input, submit button, run button, encoded word display.
    """
    return (
        gr.update(interactive=True, label="Enter your message"),
        gr.update(interactive=True),
        gr.update(interactive=False),
        gr.update(visible=False),
    )


def _reset_after_inference():
    """Reset the message widgets and hide the loading indicator."""
    return _reset_message_controls() + (gr.update(visible=False),)


# Create the Gradio interface
with gr.Blocks() as demo:
    # Load the dataset, model and processor when the app starts.
    load_resources()

    gr.Markdown("# Groundlight's Visual Reasoning Model - Cryptogram Decoder")

    # Per-session state: the current decoder mapping and its image.
    current_mapping = gr.State()
    current_image = gr.State()

    with gr.Row():
        # Left column - Inputs
        with gr.Column(scale=1):
            gr.Textbox(
                value=instructions,
                label="Instructions",
                interactive=False,
                lines=4,
            )
            # Image display component
            image_output = gr.Image(label="Decoder")
            # Button to load a new random example
            next_button = gr.Button("Generate Random Decoder")
            # Text input for the word
            word_input = gr.Textbox(
                label="Enter your message",
                placeholder="Enter message here...",
                max_lines=1,
                show_copy_button=False,
            )
            gr.Markdown(
                "Note: Only English letters and spaces are allowed. Please do not enter any numbers or punctuation."
            )
            # Encoded word display; hidden until a message has been encoded.
            encoded_word_display = gr.Textbox(
                label="Encoded Message",
                interactive=False,
                visible=False,
                max_lines=1,
                show_copy_button=True,
            )
            # Group submit and run buttons vertically
            with gr.Column():
                submit_button = gr.Button("Encode Message")
                run_button = gr.Button("Run Model", interactive=False)

        # Right column - Outputs
        with gr.Column(scale=1):
            # Output area for the model response
            model_output = gr.Textbox(
                label="Model Output",
                interactive=False,
                lines=40,
                max_lines=80,
                container=True,
                show_copy_button=True,
                visible=True,
            )
            # Loading indicator
            loading_indicator = gr.HTML(visible=False)

    # Event handlers
    message_controls = [word_input, submit_button, run_button, encoded_word_display]

    # A new decoder invalidates any message encoded with the previous one:
    # inference re-encodes with the current mapping, so a stale "Encoded
    # message" would no longer match what the model is given. Reset the
    # message widgets whenever a new decoder is loaded.
    next_button.click(
        fn=show_random_example,
        outputs=[image_output, current_mapping, current_image],
    ).then(
        fn=_reset_message_controls,
        inputs=None,
        outputs=message_controls,
    )

    # Validate word on submit and update interface
    submit_button.click(
        fn=validate_and_submit,
        inputs=[word_input, current_mapping],
        outputs=message_controls,
    )

    # Prepare the UI, stream the model output, then reset the controls.
    run_button.click(
        fn=prepare_for_inference,
        outputs=[model_output, run_button, loading_indicator],
    ).then(
        fn=run_inference,
        inputs=[word_input, current_image, current_mapping],
        outputs=model_output,
        api_name=False,
    ).then(
        fn=_reset_after_inference,
        inputs=None,
        outputs=message_controls + [loading_indicator],
    )

if __name__ == "__main__":
    # For local testing, launch with:
    #   demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
    # The plain launch below is what the hosted Space uses.
    demo.launch()