Encyclopedia / app.py
baconnier's picture
Update app.py
8d4adac verified
raw
history blame
7.71 kB
import os
import openai
import gradio as gr
import json
import plotly.graph_objects as go
from variables import CONTEXTUAL_ZOOM_PROMPT, CONTEXTUAL_ZOOM_default_response
class ArtExplorer:
    """Interactive art-history explorer built with Gradio.

    A free-text query is sent to an LLM through Groq's OpenAI-compatible
    endpoint. The JSON it returns drives three UI axes: a time slider, a map
    of locations and a multi-select list of artistic styles. Each axis has a
    "zoom" button that re-queries the LLM with the current selection on that
    axis.

    NOTE(review): ``current_state`` lives on this one instance, and the
    handlers built in ``create_interface`` close over ``self``. The state is
    therefore shared by everyone using the launched app. Per-session state
    would need ``gr.State``.
    """

    # Groq model id used for every completion; override on a subclass or an
    # instance to switch models.
    MODEL = "mixtral-8x7b-32768"

    # (name, lat, lon) marker shown when no usable coordinates are available.
    DEFAULT_LOCATION = ("Paris", 48.8566, 2.3522)

    # Axes whose ``current_zoom`` section the UI handlers read.
    _UI_AXES = ("temporal", "geographical", "style")

    def __init__(self):
        self.client = openai.OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=os.environ.get("GROQ_API_KEY"),
        )
        # zoom_level: zoom clicks since the last search (sent to the LLM).
        # selections: last selection sent to the LLM, keyed by axis name.
        self.current_state = {
            "zoom_level": 0,
            "selections": {},
        }

    @staticmethod
    def _valid_points(locations):
        """Return ``(name, lat, lon)`` tuples for entries with numeric coordinates.

        Entries are skipped if they are not dicts, or if their ``lat``/``lon``
        is missing or cannot be converted with ``float``. One bad item from the
        LLM therefore cannot break the whole map.
        """
        points = []
        for loc in locations or []:
            if not isinstance(loc, dict):
                continue
            try:
                lat, lon = float(loc["lat"]), float(loc["lon"])
            except (KeyError, TypeError, ValueError):
                continue
            points.append((loc.get("name"), lat, lon))
        return points

    @staticmethod
    def _center(points):
        """Return the mean ``(lat, lon)`` of a non-empty list of ``(name, lat, lon)``."""
        lats = [lat for _, lat, _ in points]
        lons = [lon for _, _, lon in points]
        return sum(lats) / len(lats), sum(lons) / len(lons)

    def create_map(self, locations, zoom=4):
        """Create a Plotly map figure from location data.

        Args:
            locations: Sequence of dicts with ``lat``/``lon`` (and optionally
                ``name``) keys. Invalid entries are dropped. If none remain, a
                single Paris marker is shown.
            zoom: Initial mapbox zoom level.

        Returns:
            A ``plotly.graph_objects.Figure`` centred on the mean position of
            the plotted markers.
        """
        points = self._valid_points(locations) or [self.DEFAULT_LOCATION]
        center_lat, center_lon = self._center(points)
        fig = go.Figure(go.Scattermapbox(
            lat=[lat for _, lat, _ in points],
            lon=[lon for _, _, lon in points],
            mode="markers",
            marker=go.scattermapbox.Marker(size=10),
            text=[name for name, _, _ in points],
        ))
        fig.update_layout(
            mapbox_style="open-street-map",
            # Centre on the markers themselves: a fixed centre can leave
            # distant markers off-screen at this zoom.
            mapbox=dict(center=dict(lat=center_lat, lon=center_lon), zoom=zoom),
            margin=dict(r=0, t=0, l=0, b=0),
        )
        return fig

    @staticmethod
    def _parse_llm_json(text):
        """Parse the JSON object contained in an LLM reply.

        Any surrounding prose or Markdown code fences are ignored by slicing
        from the first ``{`` to the last ``}``.

        Raises:
            ValueError: If ``text`` is not a string, contains no JSON object,
                or the object is malformed (``json.JSONDecodeError`` is a
                ``ValueError``).
        """
        if not isinstance(text, str):
            raise ValueError("LLM response has no text content")
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise ValueError("no JSON object in LLM response")
        return json.loads(text[start:end + 1])

    @classmethod
    def _has_axis_configs(cls, config):
        """Return True if ``config`` has a ``current_zoom`` dict for every UI axis."""
        if not isinstance(config, dict):
            return False
        axes = config.get("axis_configurations")
        if not isinstance(axes, dict):
            return False
        return all(
            isinstance(axes.get(axis), dict)
            and isinstance(axes[axis].get("current_zoom"), dict)
            for axis in cls._UI_AXES
        )

    def get_llm_response(self, query: str, zoom_context: dict = None) -> dict:
        """Ask the LLM for the axis configurations matching ``query``.

        Args:
            query: The user's free-text art-history query.
            zoom_context: Optional mapping of an axis name (``temporal``,
                ``geographical``, ``style`` or ``subject``) to the current
                selection on that axis. Unknown axis names are ignored.

        Returns:
            The parsed JSON reply. It falls back to ``get_default_response()``
            in three cases: the request fails, the reply holds no valid JSON
            object, or the object lacks a ``current_zoom`` section for one of
            the UI axes.
        """
        try:
            level = self.current_state["zoom_level"]
            current_zoom_states = {
                axis: {"level": level, "selection": ""}
                for axis in ("temporal", "geographical", "style", "subject")
            }
            if zoom_context:
                for key, value in zoom_context.items():
                    if key in current_zoom_states:
                        current_zoom_states[key]["selection"] = value
            messages = [
                {"role": "system", "content": "You are an expert art historian specializing in interactive exploration."},
                {"role": "user", "content": CONTEXTUAL_ZOOM_PROMPT.format(
                    user_query=query,
                    current_zoom_states=json.dumps(current_zoom_states, indent=2),
                )},
            ]
            print("Sending request to LLM...")
            response = self.client.chat.completions.create(
                model=self.MODEL,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
            )
            print("Received response from LLM")
            result = self._parse_llm_json(response.choices[0].message.content)
        except Exception as e:
            # Any failure (network, API, malformed JSON) falls back to the
            # canned response, so the UI keeps working.
            print(f"Error in LLM response: {str(e)}")
            return self.get_default_response()
        # Validate here so the UI handlers never hit a KeyError on a reply
        # that is valid JSON but has the wrong shape.
        if not self._has_axis_configs(result):
            print("Error in LLM response: missing axis_configurations")
            return self.get_default_response()
        return result

    def get_default_response(self):
        """Return the fallback configuration ``CONTEXTUAL_ZOOM_default_response``."""
        return CONTEXTUAL_ZOOM_default_response

    @staticmethod
    def _slider_update(time_range):
        """Turn the LLM's temporal ``range`` into a Gradio slider update.

        ``gr.Slider`` holds a single number, so the range is handled by shape:
        - A two-element ``[start, end]`` narrows the slider bounds to that span
          and moves the slider to ``start``.
        - A lone number just moves the slider.
        - Anything else leaves the slider unchanged.
        """
        # NOTE(review): the exact shape of ``range`` is set by
        # CONTEXTUAL_ZOOM_PROMPT, which is not visible here, so both shapes
        # are accepted.
        if isinstance(time_range, (list, tuple)) and len(time_range) == 2:
            try:
                start, end = sorted(float(v) for v in time_range)
            except (TypeError, ValueError):
                return gr.update()
            if start == end:
                return gr.update(value=start)
            return gr.update(minimum=start, maximum=end, value=start)
        if isinstance(time_range, (int, float)) and not isinstance(time_range, bool):
            return gr.update(value=time_range)
        return gr.update()

    @staticmethod
    def _style_update(options):
        """Install ``options`` as the style dropdown's choices and clear its selection.

        A plain list returned to a Dropdown would set its *value*, which must
        be drawn from its ``choices``. The options are therefore sent as
        choices instead.
        """
        if not isinstance(options, (list, tuple)):
            options = []
        return gr.update(choices=list(options), value=[])

    def _geo_map(self, geo_config):
        """Build the map for ``geo_config`` and remember the plotted place names.

        The names are stored as the geographical selection, so a later zoom
        can tell the LLM which places are currently shown.
        """
        locations = geo_config.get("locations")
        self.current_state["selections"]["geographical"] = [
            name for name, _, _ in self._valid_points(locations) if name
        ]
        return self.create_map(locations)

    def create_interface(self):
        """Build the Gradio Blocks UI and wire its event handlers.

        Returns:
            The ``gr.Blocks`` app, ready for ``launch()``.
        """
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")
            with gr.Row():
                query = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy",
                )
                search_btn = gr.Button("Explore")
            with gr.Row():
                # Temporal axis
                with gr.Column():
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True,
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("🔍 Zoom Time Period")
                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("🔍 Zoom Geography")
            with gr.Row():
                style_select = gr.Dropdown(
                    multiselect=True,
                    label="Artistic Styles",
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("🔍 Zoom Styles")

            def initial_search(query):
                """Run a fresh query and return one output per wired component.

                The zoom state is reset first, because a new query starts a new
                exploration.
                """
                self.current_state["zoom_level"] = 0
                self.current_state["selections"] = {}
                axes = self.get_llm_response(query)["axis_configurations"]
                temporal = axes["temporal"]["current_zoom"]
                geographical = axes["geographical"]["current_zoom"]
                style = axes["style"]["current_zoom"]
                # Order must match the ``outputs`` list of search_btn.click.
                return (
                    self._slider_update(temporal.get("range")),  # time_slider
                    self._geo_map(geographical),                 # map_plot
                    self._style_update(style.get("options")),    # style_select
                    temporal.get("explanation", ""),             # time_explanation
                    geographical.get("explanation", ""),         # geo_explanation
                    style.get("explanation", ""),                # style_explanation
                )

            def zoom_axis(query, axis_name, current_value):
                """Zoom one axis and return ``(component update, explanation)``."""
                self.current_state["zoom_level"] += 1
                selections = self.current_state["selections"]
                if axis_name == "geographical":
                    # The map's own value is the rendered figure, not a user
                    # selection. Describe it by the place names plotted last.
                    current_value = selections.get("geographical", "")
                else:
                    selections[axis_name] = current_value
                config = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value},
                )
                axis_config = config["axis_configurations"][axis_name]["current_zoom"]
                explanation = axis_config.get("explanation", "")
                if axis_name == "temporal":
                    return self._slider_update(axis_config.get("range")), explanation
                if axis_name == "geographical":
                    return self._geo_map(axis_config), explanation
                return self._style_update(axis_config.get("options")), explanation

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query],
                outputs=[
                    time_slider,
                    map_plot,
                    style_select,
                    time_explanation,
                    geo_explanation,
                    style_explanation,
                ],
            )
            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query, time_slider],
                outputs=[time_slider, time_explanation],
            )
            geo_zoom.click(
                # The plot is not an input: zoom_axis uses the stored place names.
                fn=lambda q: zoom_axis(q, "geographical", None),
                inputs=[query],
                outputs=[map_plot, geo_explanation],
            )
            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query, style_select],
                outputs=[style_select, style_explanation],
            )
        return demo
def main():
    """Build the Art History Explorer UI and launch the Gradio server."""
    print("Starting initialization...")
    app = ArtExplorer()
    print("Created ArtExplorer instance")
    ui = app.create_interface()
    print("Created interface")
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": True,  # also request a public Gradio share link
    }
    ui.launch(**launch_options)


if __name__ == "__main__":
    main()