Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import VisionEncoderDecoderModel, DonutProcessor | |
from PIL import Image | |
import torch | |
# Identifier handed to from_pretrained for both the model weights and the processor.
model_checkpoint_path = "Muhammad2019abdelfattah/Unichart_Fine-tuning"

# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the encoder-decoder model and its matching processor from the same checkpoint.
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint_path)
# NOTE(review): assumes the checkpoint ships a DonutProcessor config - confirm.
processor = DonutProcessor.from_pretrained(model_checkpoint_path)

# Move the weights to the selected device so they match the inputs built in generate_summary.
model.to(device)
def generate_summary(image: Image.Image) -> str:
    """Generate a text summary for a chart image.

    Uses the module-level ``processor``, ``model`` and ``device``: the image
    is converted to RGB and preprocessed, the decoder is primed with the
    ``<summarize_chart> <s_answer>`` prompt, and the text following the
    ``<s_answer>`` marker in the decoded output is returned.

    Args:
        image: The chart as a PIL image. ``None`` (what Gradio passes when the
            form is submitted without an image) is accepted and reported to
            the user instead of being treated as an internal failure.

    Returns:
        The stripped summary text. If no image was supplied, a prompt asking
        for one. If anything raises during preprocessing, generation or
        decoding, the fixed string "An error occurred during summarization."
        (the exception message is printed to stdout).
    """
    # Guard first: without it, `None.convert` raises AttributeError and the user
    # only sees the generic failure message below.
    if image is None:
        return "Please upload a chart image first."

    answer_marker = "<s_answer>"
    try:
        tokenizer = processor.tokenizer

        # Task prompt; the model continues generating after the answer marker.
        input_prompt = "<summarize_chart> " + answer_marker

        # Preprocess the image into the tensor the encoder expects.
        img = image.convert("RGB")
        pixel_values = processor(img, return_tensors="pt").pixel_values.to(device)

        # The prompt is fed verbatim, so no BOS/EOS tokens are added around it.
        decoder_input_ids = tokenizer(
            input_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        # Beam search (4 beams), capped at 512 tokens, with the unknown token banned.
        outputs = model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,  # Adjust max_length as needed
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=4,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode the single generated sequence and drop EOS / padding markers.
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")

        # The summary is the segment right after the first answer marker. If the
        # marker is absent from the decoded text, fall back to the whole cleaned
        # sequence rather than failing with IndexError and discarding the output.
        parts = sequence.split(answer_marker)
        summary = parts[1] if len(parts) > 1 else parts[0]
        return summary.strip()
    except Exception as e:
        # UI boundary: report a friendly message instead of surfacing a traceback.
        print(f"An error occurred: {e}")
        return "An error occurred during summarization."
# Build the Gradio UI: a single image input wired to generate_summary, with a text output.
chart_input = gr.Image(type="pil")  # type="pil" makes the callback receive a PIL image
iface = gr.Interface(
    fn=generate_summary,
    inputs=chart_input,
    outputs="text",
    title="Chart Summarization",
    description="Upload a chart image to get a summary based on the image content.",
)

# Start the app. No port is specified here; share=True asks Gradio for a public share link.
iface.launch(share=True)