Encyclopedia / app.py
baconnier's picture
Update app.py
d954da7 verified
raw
history blame
8.01 kB
import os
import openai
import gradio as gr
import json
import plotly.graph_objects as go
from variables import *
class ArtExplorer:
    """Interactive art-history explorer backed by an LLM and a Gradio UI.

    A query (and each subsequent "zoom" on one axis) is sent to an LLM through
    Groq's OpenAI-compatible API. The JSON it returns configures three UI
    axes: a time slider, a map of locations, and a style dropdown.
    """

    # Marker / map centre used when the LLM supplies no plottable locations.
    _DEFAULT_LOCATION = {"name": "Paris", "lat": 48.8566, "lon": 2.3522}
    # Axes whose zoom state is reported to the LLM in every prompt.
    _ZOOM_AXES = ("temporal", "geographical", "style", "subject")

    def __init__(self):
        """Create the Groq-backed client and the exploration state.

        The API key is read from the ``GROQ_API_KEY`` environment variable.
        """
        self.client = openai.OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=os.environ.get("GROQ_API_KEY"),
        )
        # NOTE(review): the script builds a single ArtExplorer, so this state is
        # shared by every browser session rather than kept per user.
        self.current_state = {
            "zoom_level": 0,   # zoom steps taken since the last search
            "selections": {},  # axis name -> recorded selection (currently "geographical")
        }

    @classmethod
    def _plottable_points(cls, locations):
        """Return ``[(name, lat, lon), ...]`` for locations with numeric coordinates.

        Entries that are not mappings, or whose "lat"/"lon" are missing or not
        convertible to float, are skipped. When nothing qualifies, the default
        (Paris) point is returned so the map is never empty.
        """
        points = []
        for loc in locations or []:
            try:
                lat, lon = float(loc["lat"]), float(loc["lon"])
            except (KeyError, TypeError, ValueError):
                continue  # LLM output without usable coordinates
            points.append((loc.get("name"), lat, lon))
        if not points:
            default = cls._DEFAULT_LOCATION
            points.append((default["name"], default["lat"], default["lon"]))
        return points

    @staticmethod
    def _location_names(locations):
        """Return the non-empty "name" values of the dict entries in ``locations``."""
        return [
            loc["name"]
            for loc in locations or []
            if isinstance(loc, dict) and loc.get("name")
        ]

    @staticmethod
    def _strip_code_fences(text):
        """Strip whitespace and a surrounding Markdown code fence (e.g. ```json ... ```)."""
        stripped = text.strip()
        if not stripped.startswith("```"):
            return stripped
        body = stripped[3:]
        if body.endswith("```"):
            body = body[:-3]
        # Drop a language tag such as "json" on the opening fence line.
        first_line, newline, rest = body.partition("\n")
        if newline and not first_line.strip().startswith(("{", "[")):
            body = rest
        return body.strip()

    @staticmethod
    def _is_axis_config(result):
        """Return True if ``result`` has the top-level shape the UI handlers index into."""
        return isinstance(result, dict) and isinstance(
            result.get("axis_configurations"), dict
        )

    def create_map(self, locations):
        """Create a Plotly map figure with one marker per location.

        Args:
            locations: iterable of dicts with "name", "lat" and "lon" keys (as
                returned by the LLM), or ``None``. Entries without numeric
                coordinates are skipped.

        Returns:
            A ``plotly.graph_objects.Figure`` centred on the mean of the
            plotted coordinates. A single Paris marker is shown when nothing
            is plottable.
        """
        points = self._plottable_points(locations)
        names = [name for name, _, _ in points]
        lats = [lat for _, lat, _ in points]
        lons = [lon for _, _, lon in points]
        fig = go.Figure(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode="markers",
            marker=go.scattermapbox.Marker(size=10),
            text=names,
        ))
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox=dict(
                # Centre on the data instead of always on Paris. A plain mean is
                # adequate for clusters that do not straddle the antimeridian.
                center=dict(lat=sum(lats) / len(lats), lon=sum(lons) / len(lons)),
                zoom=4,
            ),
            margin=dict(r=0, t=0, l=0, b=0),
        )
        return fig

    def get_llm_response(self, query: str, zoom_context: "dict | None" = None) -> dict:
        """Ask the LLM for the axis configurations matching ``query``.

        Args:
            query: the user's free-text art-history query.
            zoom_context: optional mapping of axis name ("temporal",
                "geographical", "style", "subject") to the value being zoomed
                into. Unknown axis names are ignored.

        Returns:
            The decoded JSON object from the model when it contains an
            ``axis_configurations`` mapping. Otherwise (API error, undecodable
            JSON, unexpected shape) the result of ``get_default_response``.
        """
        content = None
        try:
            level = self.current_state["zoom_level"]
            current_zoom_states = {
                axis: {"level": level, "selection": ""} for axis in self._ZOOM_AXES
            }
            for axis, value in (zoom_context or {}).items():
                if axis in current_zoom_states:
                    current_zoom_states[axis]["selection"] = value
            messages = [
                {"role": "system", "content": "You are an expert art historian specializing in interactive exploration."},
                {"role": "user", "content": CONTEXTUAL_ZOOM_PROMPT.format(
                    user_query=query,
                    current_zoom_states=json.dumps(current_zoom_states, indent=2),
                )},
            ]
            # Debug trace of the prompt sent to the LLM.
            print("Messages sent to LLM:")
            for message in messages:
                print(f"{message['role']}: {message['content']}")
            response = self.client.chat.completions.create(
                model="mixtral-8x7b-32768",
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
            )
            print("Raw response from LLM:")
            print(response)
            content = response.choices[0].message.content
            # Models often wrap JSON in a Markdown fence; strip it before parsing.
            result = json.loads(self._strip_code_fences(content or ""))
            print("Parsed result from LLM:")
            print(result)
        except json.JSONDecodeError as json_err:
            print(f"JSON decode error: {json_err}")
            print(f"Response content: {content}")
            return self.get_default_response()
        except Exception as e:
            print(f"Error in LLM response: {e}")
            return self.get_default_response()
        if not self._is_axis_config(result):
            # Valid JSON of the wrong shape would otherwise crash the UI handlers.
            print("LLM response lacks 'axis_configurations'; using default response")
            return self.get_default_response()
        return result

    def get_default_response(self):
        """Return the fallback configuration ``CONTEXTUAL_ZOOM_default_response``."""
        return CONTEXTUAL_ZOOM_default_response

    def create_interface(self):
        """Build the Gradio Blocks UI and wire up its event handlers.

        Returns:
            The ``gr.Blocks`` app, ready for ``launch()``.
        """
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")
            with gr.Row():
                query = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy",
                )
                search_btn = gr.Button("Explore")
            with gr.Row():
                # Temporal axis
                with gr.Column():
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True,
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("🔍 Zoom Time Period")
                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("🔍 Zoom Geography")
            # Style axis
            with gr.Row():
                style_select = gr.Dropdown(
                    multiselect=True,
                    label="Artistic Styles",
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("🔍 Zoom Styles")

            def initial_search(query):
                """Handle "Explore": fetch a fresh configuration and populate all axes."""
                # A new query starts a new exploration, so zooming restarts at level 0.
                self.current_state["zoom_level"] = 0
                config = self.get_llm_response(query)
                axes = config["axis_configurations"]
                temporal_config = axes["temporal"]["current_zoom"]
                geographical_config = axes["geographical"]["current_zoom"]
                style_config = axes["style"]["current_zoom"]
                locations = geographical_config.get("locations")
                self.current_state["selections"]["geographical"] = self._location_names(locations)
                return {
                    # NOTE(review): gr.Slider holds a single number; confirm the
                    # prompt's "range" field is a scalar, not a [start, end] pair.
                    time_slider: temporal_config["range"],
                    map_plot: self.create_map(locations),
                    style_select: gr.Dropdown(choices=style_config["options"]),
                    time_explanation: temporal_config["explanation"],
                    geo_explanation: geographical_config.get("explanation", ""),
                    style_explanation: style_config["explanation"],
                }

            def zoom_axis(query, axis_name, current_value):
                """Zoom one axis a level deeper around ``current_value``; refresh only that axis."""
                self.current_state["zoom_level"] += 1
                if axis_name == "geographical":
                    # The map Plot has no meaningful "value" to send (and it may not
                    # be JSON-serializable, which made the request fall back to the
                    # default). Describe the selection by the plotted location names.
                    current_value = self.current_state["selections"].get("geographical", [])
                config = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value},
                )
                axis_config = config["axis_configurations"][axis_name]["current_zoom"]
                if axis_name == "temporal":
                    return {
                        time_slider: axis_config["range"],
                        time_explanation: axis_config["explanation"],
                    }
                if axis_name == "geographical":
                    locations = axis_config.get("locations")
                    self.current_state["selections"]["geographical"] = self._location_names(locations)
                    return {
                        map_plot: self.create_map(locations),
                        geo_explanation: axis_config.get("explanation", ""),
                    }
                # style
                return {
                    style_select: gr.Dropdown(choices=axis_config["options"]),
                    style_explanation: axis_config["explanation"],
                }

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query],
                outputs=[
                    time_slider,
                    map_plot,
                    style_select,
                    time_explanation,
                    geo_explanation,
                    style_explanation,
                ],
            )
            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query, time_slider],
                outputs=[time_slider, time_explanation],
            )
            # The geographic selection comes from stored state, so the figure
            # itself is not shipped back from the browser.
            geo_zoom.click(
                fn=lambda q: zoom_axis(q, "geographical", None),
                inputs=[query],
                outputs=[map_plot, geo_explanation],
            )
            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query, style_select],
                outputs=[style_select, style_explanation],
            )
        return demo
if __name__ == "__main__":
    # Script entry point: create the explorer, assemble its Gradio UI,
    # then start the Gradio server.
    explorer = ArtExplorer()
    demo = explorer.create_interface()
    demo.launch()