Encyclopedia / app.py
baconnier's picture
Update app.py
8f7fdc3 verified
raw
history blame
8.19 kB
import os
import openai
import gradio as gr
import json
from typing import Optional, Dict, Any, List
import plotly.graph_objects as go
import instructor
from models import ArtHistoryResponse, Location
from variables import CONTEXTUAL_ZOOM_PROMPT, CONTEXTUAL_ZOOM_default_response
# instructor must be applied to a concrete client object, e.g.
# ``instructor.patch(openai.OpenAI(...))``, for ``response_model=`` to be
# accepted by ``client.chat.completions.create``. The argument-less global
# ``instructor.patch()`` belonged to the pre-openai-1.0 SDK. Instructor releases
# that support ``openai.OpenAI`` (used below) raise when it is called without a
# client, which would abort the app at import time.
class ArtExplorer:
    """Interactive art-history explorer backed by an LLM served through Groq.

    The UI exposes three "zoomable" axes: time period, geography and style.
    Each search or zoom click asks the LLM for a fresh ``ArtHistoryResponse``
    and renders the relevant axis configuration from it.

    NOTE(review): the Gradio event handlers close over ``self``, so every
    browser session served by one instance shares ``current_state``. Consider
    ``gr.State`` if per-user isolation matters.
    """

    def __init__(self, model: str = "mixtral-8x7b-32768"):
        """Create the Groq-backed chat client and an empty exploration state.

        Args:
            model: Chat model name sent to Groq's OpenAI-compatible endpoint.
                Defaults to the model the app was written against; override
                it if that model is retired.
        """
        # ``response_model=`` (used in get_llm_response) is an instructor
        # extension. The client instance itself must therefore be patched;
        # without this, every request fails with an unexpected-keyword error
        # and silently falls back to the default response.
        self.client = instructor.patch(
            openai.OpenAI(
                base_url="https://api.groq.com/openai/v1",
                api_key=os.environ.get("GROQ_API_KEY"),
            )
        )
        self.model = model
        self.current_state = {
            # Number of zoom clicks since the last new search (all axes).
            "zoom_level": 0,
            # axis name -> value the user last zoomed into on that axis.
            "selections": {},
        }

    def create_map(self, locations: List[Location]) -> go.Figure:
        """Build a Plotly scatter-mapbox figure with one marker per location.

        Args:
            locations: Locations to plot. Each item exposes ``name``, ``lat``
                and ``lon``. An empty or None value falls back to a single
                Paris marker.

        Returns:
            A figure using the OpenStreetMap tile style, centred on the mean
            coordinate of the plotted markers at zoom level 4.
        """
        if not locations:
            locations = [Location(
                name="Paris",
                lat=48.8566,
                lon=2.3522,
                relevance="Default location",
            )]
        lats = [loc.lat for loc in locations]
        lons = [loc.lon for loc in locations]
        # Centre on the markers themselves. A fixed centre would leave results
        # far from that point off-screen at this zoom level.
        center = dict(lat=sum(lats) / len(lats), lon=sum(lons) / len(lons))
        fig = go.Figure(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode="markers",
            marker=go.scattermapbox.Marker(size=10),
            text=[loc.name for loc in locations],
        ))
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox=dict(center=center, zoom=4),
            margin=dict(r=0, t=0, l=0, b=0),
        )
        return fig

    def get_llm_response(self, query: str, zoom_context: Optional[Dict[str, Any]] = None) -> ArtHistoryResponse:
        """Ask the LLM for axis configurations, falling back to a default.

        Args:
            query: Free-text art-history query typed by the user.
            zoom_context: Optional mapping from axis name ("temporal",
                "geographical", "style", "subject") to the value being zoomed
                into. Unknown axis names are ignored.

        Returns:
            The instructor-validated ``ArtHistoryResponse``. On any exception
            (request, validation, prompt formatting) it returns
            ``get_default_response()`` instead.
        """
        try:
            print("\n=== Starting LLM Request ===")
            print(f"Input query: {query}")
            print(f"Zoom context: {zoom_context}")
            level = self.current_state["zoom_level"]
            remembered = self.current_state.get("selections", {})
            # Seed each axis with the selection remembered from earlier zooms,
            # so the LLM keeps context for axes that are not being zoomed now.
            current_zoom_states = {
                axis: {"level": level, "selection": remembered.get(axis, "")}
                for axis in ("temporal", "geographical", "style", "subject")
            }
            for key, value in (zoom_context or {}).items():
                if key in current_zoom_states:
                    current_zoom_states[key]["selection"] = value
            messages = [
                {"role": "system", "content": "You are an expert art historian specializing in interactive exploration."},
                {"role": "user", "content": CONTEXTUAL_ZOOM_PROMPT.format(
                    user_query=query,
                    current_zoom_states=json.dumps(current_zoom_states, indent=2),
                )},
            ]
            print("\nSending request to LLM...")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
                response_model=ArtHistoryResponse,
            )
            print("\nReceived validated response from LLM")
            return response
        except Exception as e:
            # Deliberate best-effort boundary: log everything, then keep the UI
            # usable with the canned response instead of surfacing an error.
            print(f"\nError in LLM response: {str(e)}")
            print(f"Full error details: {e.__class__.__name__}")
            import traceback
            print(traceback.format_exc())
            return self.get_default_response()

    def get_default_response(self) -> ArtHistoryResponse:
        """Return the canned response built from ``CONTEXTUAL_ZOOM_default_response``."""
        return ArtHistoryResponse(**CONTEXTUAL_ZOOM_default_response)

    def create_interface(self) -> gr.Blocks:
        """Build the Gradio Blocks UI and wire its event handlers.

        Returns:
            The ``gr.Blocks`` app. It is not launched here.
        """
        # Names of the locations currently plotted on the map. A gr.Plot input
        # delivers the serialized figure, which is not a meaningful zoom
        # selection and, depending on the Gradio version, not JSON-serializable.
        # The geography zoom therefore sends these names instead. The handlers
        # below mutate this list in place.
        plotted_place_names: List[str] = []

        def remember_places(locations) -> None:
            """Record the names of the locations about to be drawn."""
            plotted_place_names[:] = [loc.name for loc in locations or []]

        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")
            with gr.Row():
                query = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy",
                )
                search_btn = gr.Button("Explore")
            with gr.Row():
                # Temporal axis
                with gr.Column():
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True,
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("πŸ” Zoom Time Period")
                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("πŸ” Zoom Geography")
            with gr.Row():
                # Style axis
                style_select = gr.Dropdown(
                    choices=["Classical", "Modern"],
                    multiselect=True,
                    label="Artistic Styles",
                    allow_custom_value=True,
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("πŸ” Zoom Styles")

            def initial_search(query: str) -> tuple:
                """Run a new query and refresh all three axes.

                Returns:
                    (slider value, map figure, style values, time explanation,
                    geography explanation, style explanation), in the order of
                    the ``outputs`` list of ``search_btn.click``.
                """
                # A new query starts a new exploration. Drop the zoom depth and
                # selections accumulated while exploring the previous query.
                self.current_state["zoom_level"] = 0
                self.current_state["selections"] = {}
                response = self.get_llm_response(query)
                axes = response.axis_configurations
                temporal_config = axes["temporal"].current_zoom
                geographical_config = axes["geographical"].current_zoom
                style_config = axes["style"].current_zoom
                remember_places(geographical_config.locations)
                map_fig = self.create_map(geographical_config.locations)
                # NOTE(review): ``range`` is written straight into a gr.Slider
                # value, and ``options`` sets the dropdown's selected values
                # (not its choices). If models.py types ``range`` as a
                # (start, end) pair, this needs gr.update(minimum=..., maximum=...).
                # TODO confirm.
                return (
                    temporal_config.range,
                    map_fig,
                    style_config.options,
                    temporal_config.explanation,
                    geographical_config.explanation,
                    style_config.explanation,
                )

            def zoom_axis(query: str, axis_name: str, current_value: Any) -> tuple:
                """Zoom one axis into ``current_value`` and refresh only that axis.

                Returns:
                    (new axis value, explanation): the slider range for
                    "temporal", a map figure for "geographical", and style
                    options for any other axis.
                """
                self.current_state["zoom_level"] += 1
                # Remember this selection so later zooms on other axes still
                # report where this axis was focused.
                self.current_state["selections"][axis_name] = current_value
                response = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value},
                )
                axis_config = response.axis_configurations[axis_name].current_zoom
                if axis_name == "temporal":
                    return axis_config.range, axis_config.explanation
                if axis_name == "geographical":
                    remember_places(axis_config.locations)
                    return self.create_map(axis_config.locations), axis_config.explanation
                # style
                return axis_config.options, axis_config.explanation

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query],
                outputs=[
                    time_slider,
                    map_plot,
                    style_select,
                    time_explanation,
                    geo_explanation,
                    style_explanation,
                ],
            )
            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query, time_slider],
                outputs=[time_slider, time_explanation],
            )
            geo_zoom.click(
                # Zoom into the plotted place names, not the gr.Plot payload.
                fn=lambda q: zoom_axis(q, "geographical", list(plotted_place_names)),
                inputs=[query],
                outputs=[map_plot, geo_explanation],
            )
            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query, style_select],
                outputs=[style_select, style_explanation],
            )
        return demo
def main():
    """Entry point: build the explorer UI and serve it on 0.0.0.0:7860 with a share link."""
    print("Starting Art History Explorer...")
    app = ArtExplorer()
    print("Created ArtExplorer instance")
    ui = app.create_interface()
    print("Created interface")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    ui.launch(**launch_options)


if __name__ == "__main__":
    main()