toilaluan's picture
update
5f84a74
raw
history blame
2.47 kB
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import spaces
from PIL import Image

from scorer import DSGPromptProcessor
def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PIL image.

    Every question becomes a node labelled with its text and coloured green
    when its answer is truthy, red otherwise. Every dependency becomes a
    directed edge from the prerequisite question to the dependent one.

    Args:
        dependencies: Mapping of question id -> iterable of ids it depends on.
        questions: Mapping of question id -> question text (the node label).
        answers: Mapping indexed with the same keys as ``questions``; the
            truthiness of each value picks the node colour.

    Returns:
        A ``PIL.Image.Image`` containing the rendered PNG.

    Raises:
        KeyError: if a key of ``questions`` has no entry in ``answers``.
    """
    graph = nx.DiGraph()
    for node, question in questions.items():
        color = "green" if answers[node] else "red"
        graph.add_node(int(node), label=question, color=color)
    for node, deps in dependencies.items():
        for dep in deps:
            # Normalise both ends to int so edges attach to the nodes added
            # above, which are keyed by int(node).
            graph.add_edge(int(dep), int(node))

    # Use a dedicated figure and always close it: drawing on pyplot's implicit
    # current figure would stack every call's graph onto the same canvas and
    # keep that figure alive for the life of the process.
    fig, ax = plt.subplots()
    buf = io.BytesIO()
    try:
        # spring_layout is randomised; shell_layout / circular_layout also work.
        pos = nx.spring_layout(graph)
        # A dependency id with no matching question creates a node without a
        # colour attribute; draw it grey instead of failing with KeyError.
        node_colors = [graph.nodes[n].get("color", "lightgray") for n in graph.nodes()]
        nx.draw_networkx_nodes(
            graph, pos, node_color=node_colors, node_size=2000, edgecolors="black", ax=ax
        )
        nx.draw_networkx_edges(
            graph,
            pos,
            arrowstyle="-|>",
            arrows=True,
            arrowsize=20,
            connectionstyle="arc3,rad=0.1",
            ax=ax,
        )
        labels = nx.get_node_attributes(graph, "label")
        nx.draw_networkx_labels(graph, pos, labels, font_size=10, font_color="black", ax=ax)
        ax.axis("off")
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode eagerly rather than lazily on first pixel access
    return img
# Single shared prompt processor, constructed once when the module is loaded
# and reused by every request handled by process_image.
processor = DSGPromptProcessor(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
)
def process_image(image, prompt):
    """Score how well ``image`` matches ``prompt`` and visualise the result.

    Runs the DSG pipeline on the prompt (tuples -> dependencies -> questions),
    scores the image against those questions with ``processor.get_reward``,
    and renders the question dependency graph with pass/fail colouring.

    Args:
        image: The uploaded image (a PIL image, per the ``gr.Image(type="pil")``
            input of the interface).
        prompt: The text prompt the image is supposed to depict.

    Returns:
        A ``(graph_image, summary_text)`` pair, matching the interface's
        ``[gr.Image, gr.Textbox]`` outputs.
    """
    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    # Questions and reward are both conditioned on the user's prompt.
    questions, _ = processor.generate_questions(prompt, tuples.tuples, dependencies)
    # get_reward is given a one-element image list; take the scores for that image.
    reward = processor.get_reward(prompt, questions, dependencies, [image])[0]
    # Key answers by the same ids as ``questions`` so draw_colored_graph can look
    # them up. NOTE(review): assumes get_reward scores questions in the iteration
    # order of ``questions`` — confirm against scorer.
    answers = {qid: score > 0.5 for qid, score in zip(questions, reward)}
    graph_img = draw_colored_graph(dependencies, questions, answers)
    # The first output component is a gr.Image, so return the rendered graph.
    return graph_img, f"""
Question: {questions}.
Reward per question: {reward}"""
# Gradio UI wiring: an image upload and a prompt box feed process_image(image, prompt);
# its two return values fill an image component and a text box, in that order.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Enter your prompt"),
    ],
    outputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Output text"),
    ],
    title="Image and Prompt Interface",
    description="Upload an image and enter a prompt. The output is an image and text below it.",
)

# Start the web server as soon as the script runs.
interface.launch()