# Encyclopedia / app.py
# Last commit 9fff027 by baconnier: "Update app.py" (verified); file size 6.89 kB.
import os
import openai
import gradio as gr
import json
import plotly.graph_objects as go
class ArtExplorer:
    """Interactive art-history explorer backed by an LLM served through Groq.

    A free-text query is turned (by the LLM) into three axes -- a time range,
    a set of map locations and a list of artistic styles -- which are shown in
    a Gradio UI. Each axis has a "zoom" button that re-queries the LLM with the
    user's selections so far to get a more detailed answer for that axis.
    """

    # Map centre (Paris) used when no location has usable coordinates.
    _DEFAULT_CENTER = {"lat": 48.8566, "lon": 2.3522}

    def __init__(self, model="mixtral-8x7b-32768"):
        """Create the Groq-backed OpenAI client and the zoom/selection state.

        Args:
            model: Groq model ID used for every chat completion.
                NOTE(review): Groq has retired several Mixtral model IDs;
                confirm this default is still served, or pass a current one.
        """
        self.client = openai.OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=os.environ.get("GROQ_API_KEY"),
        )
        self.model = model
        # NOTE(review): this state lives on the instance, so every browser
        # session served by this process shares it. Use gr.State if
        # per-user isolation matters.
        self.current_state = {
            "zoom_level": 0,
            # Axis name -> the value the user zoomed in on for that axis.
            "selections": {},
            # Place names shown on the map by the last geographical update.
            "locations": [],
        }

    # ------------------------------------------------------------------ #
    # Pure helpers (no network / UI dependencies)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_json_reply(text):
        """Decode the JSON value in an LLM reply.

        The whole reply is tried first. If that fails (for example, the model
        wrapped the JSON in a Markdown fence or added commentary), decoding
        restarts at the first ``{`` and any trailing text is ignored.

        Raises:
            json.JSONDecodeError: if no JSON value can be decoded.
        """
        text = (text or "").strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start = text.find("{")
            if start == -1:
                raise
            obj, _ = json.JSONDecoder().raw_decode(text[start:])
            return obj

    @staticmethod
    def _year_range(value):
        """Normalise the LLM's temporal ``range`` into an ordered ``(start, end)``.

        Accepts ``[a, b]`` (any order), ``{"start": a, "end": b}`` or a single
        year. Always returns ints with ``end > start``, because ``gr.Slider``
        needs ``maximum > minimum``.

        Raises:
            ValueError / TypeError: if the value holds no parseable years.
        """
        if isinstance(value, dict):
            value = [value.get("start"), value.get("end")]
        elif isinstance(value, (int, float, str)):
            value = [value, value]
        years = sorted(int(float(v)) for v in value or [])
        if not years:
            raise ValueError("the model's reply has no temporal range")
        start, end = years[0], years[-1]
        return (start, end) if end > start else (start, start + 1)

    @staticmethod
    def _valid_locations(locations):
        """Keep only dict entries whose ``lat``/``lon`` parse as numbers.

        Coordinates are returned as floats. All other keys are preserved.
        ``None`` yields ``[]``.
        """
        valid = []
        for loc in locations or []:
            if not isinstance(loc, dict):
                continue
            try:
                lat = float(loc.get("lat"))
                lon = float(loc.get("lon"))
            except (TypeError, ValueError):
                continue
            valid.append({**loc, "lat": lat, "lon": lon})
        return valid

    @staticmethod
    def _section(config, key):
        """Return ``config[key]`` if it is a dict, otherwise raise ``ValueError``."""
        section = config.get(key) if isinstance(config, dict) else None
        if not isinstance(section, dict):
            raise ValueError(f"the model's reply has no {key!r} object")
        return section

    # ------------------------------------------------------------------ #
    # Rendering and LLM access
    # ------------------------------------------------------------------ #

    def create_map(self, locations):
        """Create a Plotly map figure from location data.

        Args:
            locations: iterable of dicts with ``name``, ``lat``, ``lon`` and an
                optional ``description``. Entries with missing or non-numeric
                coordinates are skipped.

        Returns:
            A Plotly figure with one marker per location, centred on the mean
            marker position (Paris when there are no usable locations).
        """
        points = self._valid_locations(locations)
        fig = go.Figure(go.Scattermapbox(
            lat=[p["lat"] for p in points],
            lon=[p["lon"] for p in points],
            mode="markers",
            marker=go.scattermapbox.Marker(size=10),
            text=[p.get("name") for p in points],
            # Show the LLM's per-place description under the name on hover.
            customdata=[p.get("description") or "" for p in points],
            hovertemplate="<b>%{text}</b><br>%{customdata}<extra></extra>",
        ))
        if points:
            center = {
                "lat": sum(p["lat"] for p in points) / len(points),
                "lon": sum(p["lon"] for p in points) / len(points),
            }
        else:
            center = dict(self._DEFAULT_CENTER)
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox=dict(center=center, zoom=4),
            margin=dict(r=0, t=0, l=0, b=0),
            showlegend=False,
        )
        return fig

    def get_llm_response(self, query: str, zoom_context: dict = None) -> dict:
        """Ask the LLM for the three-axis configuration for ``query``.

        Args:
            query: free-text art-history query typed by the user.
            zoom_context: selections made so far, keyed by axis name. It is
                serialised into the prompt, and ``None``/empty means none.

        Returns:
            The decoded JSON reply. The prompt requests
            ``{"temporal": {"range", "explanation"}, "geographical":
            {"locations", "explanation"}, "style": {"options", "explanation"}}``,
            but the model is not guaranteed to comply, so callers must validate.

        Raises:
            json.JSONDecodeError: if no JSON can be decoded from the reply.
            openai.OpenAIError: if the API request fails.
        """
        system_prompt = (
            "You are an art history expert. Generate a structured JSON "
            "configuration for an interactive art exploration interface. "
            "Reply with a single JSON object and nothing else."
        )
        selections = json.dumps(zoom_context) if zoom_context else "None"
        # The schema is spelled out because the UI reads these exact keys.
        user_prompt = "\n".join([
            f"Query: {query}",
            f"Zoom Level: {self.current_state['zoom_level']}",
            f"Current Selections: {selections}",
            "Return a JSON object with configurations for temporal, geographical, "
            "and style axes, using exactly this structure:",
            '{"temporal": {"range": [start_year, end_year], "explanation": "..."},',
            ' "geographical": {"locations": [',
            '   {"name": "Paris", "lat": 48.8566, "lon": 2.3522, "description": "Center of French art"},',
            '   {"name": "Florence", "lat": 43.7696, "lon": 11.2558, "description": "Renaissance art hub"}',
            ' ], "explanation": "..."},',
            ' "style": {"options": ["style name", "..."], "explanation": "..."}}',
            "Years are integers (negative for BCE). Give exact decimal coordinates "
            "for every location.",
        ])
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            max_tokens=1024,
            # Groq's OpenAI-compatible JSON mode: constrains output to valid JSON.
            response_format={"type": "json_object"},
        )
        return self._parse_json_reply(response.choices[0].message.content)

    # ------------------------------------------------------------------ #
    # UI
    # ------------------------------------------------------------------ #

    def create_interface(self):
        """Build the Gradio Blocks UI and wire up its event handlers.

        Returns:
            The (not yet launched) ``gr.Blocks`` app.
        """
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")
            with gr.Row():
                query = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy",
                )
                search_btn = gr.Button("Explore")
            with gr.Row():
                # Temporal axis
                with gr.Column():
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True,
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("🔍 Zoom Time Period")
                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("🔍 Zoom Geography")
            with gr.Row():
                style_select = gr.Dropdown(
                    multiselect=True,
                    label="Artistic Styles",
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("🔍 Zoom Styles")

            def require_query(query_text):
                """Return the stripped query, or raise a UI error if it is empty."""
                text = (query_text or "").strip()
                if not text:
                    raise gr.Error("Please enter an art history query first.")
                return text

            def temporal_updates(config):
                """Build slider bounds and the explanation for the temporal axis."""
                section = self._section(config, "temporal")
                start, end = self._year_range(section.get("range"))
                return {
                    # gr.Slider holds one number, so the [start, end] range
                    # becomes the slider's bounds rather than its value.
                    time_slider: gr.Slider(minimum=start, maximum=end, value=start),
                    time_explanation: section.get("explanation", ""),
                }

            def geographical_updates(config):
                """Build the map figure and explanation, and remember the shown places."""
                section = self._section(config, "geographical")
                locations = self._valid_locations(section.get("locations"))
                self.current_state["locations"] = [
                    loc["name"] for loc in locations if loc.get("name")
                ]
                return {
                    map_plot: self.create_map(locations),
                    geo_explanation: section.get("explanation", ""),
                }

            def style_updates(config):
                """Build the style choices (clearing stale picks) and explanation."""
                section = self._section(config, "style")
                options = section.get("options") or []
                if isinstance(options, str):
                    options = [options]
                return {
                    style_select: gr.Dropdown(
                        choices=[str(option) for option in options], value=[]
                    ),
                    style_explanation: section.get("explanation", ""),
                }

            axis_updates = {
                "temporal": temporal_updates,
                "geographical": geographical_updates,
                "style": style_updates,
            }

            def fetch_updates(query_text, axes, zoom_context=None):
                """Query the LLM and build updates for ``axes``, reporting failures in the UI."""
                try:
                    config = self.get_llm_response(query_text, zoom_context=zoom_context)
                    updates = {}
                    for axis in axes:
                        updates.update(axis_updates[axis](config))
                    return updates
                except (ValueError, TypeError, openai.OpenAIError) as err:
                    raise gr.Error(
                        f"Could not build the explorer from the model's answer: {err}"
                    ) from err

            def initial_search(query_text):
                """Start a fresh exploration for the query across all three axes."""
                query_text = require_query(query_text)
                # Reset before calling the LLM so the prompt reports zoom level 0.
                self.current_state.update(zoom_level=0, selections={}, locations=[])
                return fetch_updates(query_text, ("temporal", "geographical", "style"))

            def zoom_axis(query_text, axis_name, current_value):
                """Record the selection on ``axis_name`` and re-query that axis in more depth."""
                query_text = require_query(query_text)
                if axis_name == "geographical":
                    # gr.Plot's value is the rendered figure, not a user
                    # selection; send the place names last shown instead.
                    current_value = list(self.current_state.get("locations", []))
                self.current_state["zoom_level"] += 1
                self.current_state["selections"][axis_name] = current_value
                # Send every selection made so far, not only this axis.
                return fetch_updates(
                    query_text,
                    (axis_name,),
                    zoom_context=dict(self.current_state["selections"]),
                )

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query],
                outputs=[time_slider, map_plot, style_select,
                         time_explanation, geo_explanation, style_explanation],
            )
            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query, time_slider],
                outputs=[time_slider, time_explanation],
            )
            geo_zoom.click(
                fn=lambda q: zoom_axis(q, "geographical", None),
                inputs=[query],
                outputs=[map_plot, geo_explanation],
            )
            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query, style_select],
                outputs=[style_select, style_explanation],
            )
        return demo
if __name__ == "__main__":
    # Script entry point: build the explorer, assemble its Gradio UI, then serve it.
    explorer = ArtExplorer()
    demo = explorer.create_interface()
    demo.launch()