Graphify / app.py
ZahirJS's picture
Update app.py
7e08bc6 verified
raw
history blame
6.51 kB
import gradio as gr
import json
from graphviz import Digraph
import base64
def _add_subnodes(dot, parent_id, subnodes):
    """
    Recursively add ``subnodes`` and all of their descendants below ``parent_id``.

    Args:
        dot: The graphviz ``Digraph`` being built.
        parent_id (str): ID of the already-created node the children attach to.
        subnodes (list): Child descriptors; each needs ``id``, ``label`` and
            ``relationship`` and may carry its own ``subnodes`` list.

    Raises:
        ValueError: If a descriptor is not a JSON object or lacks a required field.
    """
    for subnode in subnodes:
        if not isinstance(subnode, dict):
            raise ValueError(f"Invalid subnode: {subnode}")

        sub_id = subnode.get('id')
        sub_label = subnode.get('label')
        sub_rel = subnode.get('relationship')
        if not all([sub_id, sub_label, sub_rel]):
            raise ValueError(f"Invalid subnode: {subnode}")

        # Subnode (rectangle with lighter fill)
        dot.node(
            sub_id,
            sub_label,
            shape='box',
            style='filled',
            fillcolor='#FFA726',
            fontcolor='white',
            fontsize='10'
        )
        dot.edge(
            parent_id,
            sub_id,
            label=sub_rel,
            color='#E91E63',
            fontsize='8'
        )

        # Descend so that deeper levels are drawn instead of being dropped.
        # An explicit null "subnodes" value is treated as "no children".
        _add_subnodes(dot, sub_id, subnode.get('subnodes') or [])


def generate_concept_map(json_input: str) -> str:
    """
    Generate concept map from JSON and return as base64 image

    The JSON must be an object with a ``central_node`` label and a ``nodes``
    list. Every node needs ``id``, ``label`` and ``relationship`` and may
    contain a ``subnodes`` list of the same shape, nested to any depth.

    Args:
        json_input (str): JSON describing the concept map structure.

    Returns:
        str: Base64 data URL of the generated concept map, or a message
        starting with ``"Error: "`` when the input cannot be rendered.
    """
    try:
        # None is treated like an empty string rather than surfacing an
        # AttributeError message from .strip().
        if not json_input or not json_input.strip():
            return "Error: Empty input"

        data = json.loads(json_input)

        # Valid JSON that is not an object (e.g. a bare number) cannot hold
        # the required keys, so it is reported the same way as missing keys.
        if (
            not isinstance(data, dict)
            or 'central_node' not in data
            or 'nodes' not in data
        ):
            raise ValueError("Missing required fields: central_node or nodes")

        # Create graph
        # NOTE(review): the Graphviz docs for splines=ortho say orthogonal
        # routing does not handle edge labels in dot -- confirm that the
        # relationship labels render as intended.
        dot = Digraph(
            name='ConceptMap',
            format='png',
            graph_attr={
                'rankdir': 'TB',
                'splines': 'ortho',
                'bgcolor': 'transparent'
            }
        )

        # Central node (ellipse)
        dot.node(
            'central',
            data['central_node'],
            shape='ellipse',
            style='filled',
            fillcolor='#2196F3',
            fontcolor='white',
            fontsize='14'
        )

        # Process nodes (rectangles)
        for node in data['nodes']:
            if not isinstance(node, dict):
                raise ValueError(f"Invalid node: {node}")

            node_id = node.get('id')
            label = node.get('label')
            relationship = node.get('relationship')

            # Validate node
            if not all([node_id, label, relationship]):
                raise ValueError(f"Invalid node: {node}")

            # Create main node (rectangle)
            dot.node(
                node_id,
                label,
                shape='box',
                style='filled',
                fillcolor='#4CAF50',
                fontcolor='white',
                fontsize='12'
            )

            # Connect to central node
            dot.edge(
                'central',
                node_id,
                label=relationship,
                color='#9C27B0',
                fontsize='10'
            )

            # Process subnodes at every nesting level
            _add_subnodes(dot, node_id, node.get('subnodes') or [])

        # Convert to base64 image
        img_data = dot.pipe(format='png')
        img_base64 = base64.b64encode(img_data).decode()
        return f"data:image/png;base64,{img_base64}"

    except json.JSONDecodeError:
        return "Error: Invalid JSON format"
    except Exception as e:
        # Errors are reported as strings instead of being raised.
        return f"Error: {str(e)}"
if __name__ == "__main__":
    # Example three-level concept map, pre-filled into the input box.
    sample_json = """
{
"central_node": "Artificial Intelligence (AI)",
"nodes": [
{
"id": "ml",
"label": "Machine Learning",
"relationship": "core_component",
"subnodes": [
{
"id": "sl",
"label": "Supervised Learning",
"relationship": "type_of",
"subnodes": [
{"id": "reg", "label": "Regression", "relationship": "technique"},
{"id": "clf", "label": "Classification", "relationship": "technique"}
]
},
{
"id": "ul",
"label": "Unsupervised Learning",
"relationship": "type_of",
"subnodes": [
{"id": "clus", "label": "Clustering", "relationship": "technique"},
{"id": "dim", "label": "Dimensionality Reduction", "relationship": "technique"}
]
}
]
},
{
"id": "nlp",
"label": "Natural Language Processing",
"relationship": "application_area",
"subnodes": [
{
"id": "sa",
"label": "Sentiment Analysis",
"relationship": "task",
"subnodes": [
{"id": "tc", "label": "Text Classification", "relationship": "method"},
{"id": "absa", "label": "Aspect-Based Sentiment Analysis", "relationship": "method"}
]
},
{
"id": "tr",
"label": "Translation",
"relationship": "task",
"subnodes": [
{"id": "nmt", "label": "Neural Machine Translation", "relationship": "method"},
{"id": "rbt", "label": "Rule-Based Translation", "relationship": "method"}
]
}
]
},
{
"id": "cv",
"label": "Computer Vision",
"relationship": "application_area",
"subnodes": [
{
"id": "od",
"label": "Object Detection",
"relationship": "task",
"subnodes": [
{"id": "yolo", "label": "YOLO", "relationship": "algorithm"},
{"id": "rcnn", "label": "R-CNN", "relationship": "algorithm"}
]
}
]
}
]
}
"""

    # Input widget: multi-line text box that starts out holding the sample.
    json_box = gr.Textbox(
        label="JSON Input",
        lines=15,
        placeholder="Paste structured JSON here...",
        value=sample_json,
    )

    # Output widget: read-only image.
    # NOTE(review): generate_concept_map returns a base64 data URL (or an
    # "Error: ..." string) while this component is declared with
    # type="filepath" -- confirm Gradio handles both values as intended.
    map_view = gr.Image(
        label="Concept Map",
        interactive=False,
        type="filepath",
    )

    demo = gr.Interface(
        fn=generate_concept_map,
        inputs=json_box,
        outputs=map_view,
        title="Advanced Concept Map Generator",
        description="Create complex concept maps from JSON with direct image output",
    )

    # Listen on every interface at port 7860, with the MCP server enabled
    # and no public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        mcp_server=True,
    )