# Hugging Face Spaces status: Runtime error
import io

import gradio as gr
import spaces
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image

from scorer import DSGPromptProcessor
def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PIL image.

    Each entry of ``questions`` becomes a node labelled with its question
    text. The node is filled green when its answer is truthy and red
    otherwise. For every node in ``dependencies``, an arrow is drawn from
    each listed dependency to that node.

    Args:
        dependencies: Mapping of node id -> iterable of node ids it depends on.
            Ids are normalised with ``int()``, so ints and int-like strings
            both work.
        questions: Mapping of node id -> question text.
        answers: Mapping of node id -> answer. Looked up by the question's
            original key first, then by its ``int()`` form. A missing answer
            is treated as false (red).

    Returns:
        PIL.Image.Image: The graph rendered to PNG and decoded into memory.
    """
    graph = nx.DiGraph()

    for node, question in questions.items():
        # Accept answers keyed like ``questions`` or by integer id, instead of
        # crashing with KeyError when the two disagree.
        answer = answers.get(node, answers.get(int(node), False))
        graph.add_node(int(node), label=question, color='green' if answer else 'red')

    for node, deps in dependencies.items():
        for dep in deps:
            # Normalise to int so edges attach to the nodes created above
            # rather than creating duplicate string-keyed nodes.
            graph.add_edge(int(dep), int(node))

    pos = nx.spring_layout(graph)  # other layouts: shell_layout, circular_layout
    # Dependencies that are not questions are added implicitly by add_edge and
    # have no attributes; colour them grey instead of raising KeyError.
    node_colors = [graph.nodes[n].get('color', 'lightgrey') for n in graph.nodes()]
    labels = nx.get_node_attributes(graph, 'label')
    node_size = 2000

    # Draw on a dedicated figure and always close it, so repeated calls neither
    # overdraw each other nor leak figures in the long-running server process.
    fig, ax = plt.subplots()
    try:
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors,
            node_size=node_size, edgecolors='black',
        )
        # Pass the same node_size so arrowheads stop at the node boundary
        # instead of being hidden underneath the large nodes.
        nx.draw_networkx_edges(
            graph, pos, ax=ax, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1', node_size=node_size,
        )
        nx.draw_networkx_labels(graph, pos, labels, font_size=10, font_color='black', ax=ax)
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now; Image.open is lazy
    return img
# Model identifier handed to DSGPromptProcessor; the instance below is shared
# by every request served by this app.
DSG_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
processor = DSGPromptProcessor(DSG_MODEL_ID)
def process_image(image, prompt):
    """Score how well ``image`` matches ``prompt`` and visualise the result.

    The pipeline uses the module-level ``processor``:

    1. Parse the prompt into tuples.
    2. Derive dependencies between the tuples.
    3. Generate questions from the tuples and dependencies.
    4. Ask the reward model to score each question against ``image``.

    Args:
        image: The uploaded image (a PIL image, per ``gr.Image(type="pil")``).
        prompt: The text prompt the image should depict.

    Returns:
        tuple: ``(graph_img, text)``. ``graph_img`` is the dependency graph
        coloured by per-question answer (for the ``gr.Image`` output).
        ``text`` lists the questions and per-question reward (for the
        ``gr.Textbox`` output).
    """
    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    # The user prompt is the input text for question generation and scoring.
    questions, _ = processor.generate_questions(
        prompt, tuples.tuples, dependencies
    )
    reward = processor.get_reward(prompt, questions, dependencies, [image])
    # get_reward receives a one-element image list; take the scores for that image.
    reward = reward[0]
    # Key answers by question id, the same keys draw_colored_graph iterates.
    # NOTE(review): assumes reward scores are ordered like iteration over
    # ``questions`` — confirm against DSGPromptProcessor.get_reward.
    answers = {qid: score > 0.5 for qid, score in zip(questions, reward)}
    graph_img = draw_colored_graph(dependencies, questions, answers)
    # The first output component is gr.Image, so return the rendered graph there.
    return graph_img, f"""
Question: {questions}.
Reward per question: {reward}"""
# Gradio UI wiring: an image and a prompt go in; process_image's two return
# values populate an image output and a text output, in that order.
input_components = [
    gr.Image(type="pil"),
    gr.Textbox(label="Enter your prompt"),
]
output_components = [
    gr.Image(type="pil"),
    gr.Textbox(label="Output text"),
]
interface = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
    title="Image and Prompt Interface",
    description="Upload an image and enter a prompt. The output is an image and text below it.",
)

# Start serving the app.
interface.launch()