"""Gradio demo that summarizes chart images with a fine-tuned UniChart model.

The checkpoint is loaded as a ``VisionEncoderDecoderModel`` together with a
``DonutProcessor``; generation is steered by a task-prompt string that is fed
to the decoder as its initial input ids.
"""

import logging

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

logger = logging.getLogger(__name__)

# Load the model and processor
model_checkpoint_path = "Muhammad2019abdelfattah/Unichart_Fine-tuning"
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint_path)
processor = DonutProcessor.from_pretrained(model_checkpoint_path)  # Assuming DonutProcessor is used

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# UniChart task prompt for chart summarization. The decoder is seeded with this
# text, so the decoded output starts with it and the summary follows ANSWER_TOKEN.
# NOTE(review): assumes the fine-tuned checkpoint kept the base UniChart special
# tokens — confirm against the checkpoint's tokenizer.
SUMMARIZE_PROMPT = "<summarize_chart> <s_answer>"
ANSWER_TOKEN = "<s_answer>"


def generate_summary(image: Image.Image) -> str:
    """Return a text summary of a chart image.

    Args:
        image: The uploaded chart. Gradio passes ``None`` when the user submits
            without an image; that case returns a prompt to upload one.

    Returns:
        The generated summary, or a user-facing message if no image was given
        or generation failed (the failure is logged with its traceback).
    """
    if image is None:
        return "Please upload a chart image."

    try:
        # Load and process the image
        img = image.convert("RGB")
        pixel_values = processor(img, return_tensors="pt").pixel_values.to(device)

        # Seed the decoder with the task prompt (no BOS/EOS added around it).
        decoder_input_ids = processor.tokenizer(
            SUMMARIZE_PROMPT, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        # Generate the summary
        outputs = model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,  # Adjust max_length as needed
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=4,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode the output and drop EOS/padding markers.
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )

        # The decoded text still contains the prompt; keep only what follows the
        # answer marker. partition() avoids an IndexError if the marker is absent.
        _, marker, answer = sequence.partition(ANSWER_TOKEN)
        return (answer if marker else sequence).strip()
    except Exception:
        # UI boundary: report a friendly message but keep the full traceback.
        logger.exception("Chart summarization failed")
        return "An error occurred during summarization."


# Create Gradio interface
iface = gr.Interface(
    fn=generate_summary,  # Function to call
    inputs=gr.Image(type="pil"),  # Input type (image)
    outputs="text",  # Output type (text)
    title="Chart Summarization",
    description="Upload a chart image to get a summary based on the image content.",
)

if __name__ == "__main__":
    # share=True additionally requests a temporary public share link.
    iface.launch(share=True)