Spaces:
Sleeping
Sleeping
import json
import os
import traceback
from typing import Any, Dict, List, Optional

import gradio as gr
import instructor
import openai
import plotly.graph_objects as go

from models import ArtHistoryResponse, Location
from variables import CONTEXTUAL_ZOOM_PROMPT, CONTEXTUAL_ZOOM_default_response
class ArtExplorer:
    """Interactive art-history explorer backed by an LLM and a Gradio UI.

    A free-text query is sent to Groq's OpenAI-compatible endpoint through an
    ``instructor``-patched client, which returns an ``ArtHistoryResponse``.
    The UI shows three axes (time, geography, style); each can be "zoomed",
    which re-queries the LLM with the current selection on that axis.

    NOTE: ``main`` creates a single instance for the whole app, so
    ``current_state`` is shared by every browser session.
    """

    def __init__(self):
        # Initialize OpenAI client with instructor patch so that
        # ``chat.completions.create`` accepts a ``response_model``.
        self.client = instructor.patch(openai.OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=os.environ.get("GROQ_API_KEY")
        ))
        # "zoom_level": number of zoom requests since the last search.
        # "selections": per-axis context the UI cannot supply directly
        # (currently the names of the locations shown on the map).
        self.current_state = {
            "zoom_level": 0,
            "selections": {}
        }

    @staticmethod
    def _map_center(locations) -> Dict[str, float]:
        """Return the arithmetic mean of the locations' coordinates.

        Args:
            locations: Non-empty sequence of objects with ``lat``/``lon``.

        Returns:
            ``{"lat": ..., "lon": ...}`` suitable for Plotly's mapbox centre.
            A plain mean does not handle points straddling the antimeridian.
        """
        count = len(locations)
        return {
            "lat": sum(loc.lat for loc in locations) / count,
            "lon": sum(loc.lon for loc in locations) / count,
        }

    def _remember_locations(self, locations) -> None:
        """Record the names of the plotted locations as the geographic selection."""
        self.current_state["selections"]["geographical"] = [
            loc.name for loc in (locations or [])
        ]

    def create_map(self, locations: List[Location]) -> go.Figure:
        """Create a Plotly map figure from location data.

        Args:
            locations: Points to plot (objects with ``name``, ``lat``, ``lon``).
                When empty or None, a single default Paris marker is shown.

        Returns:
            A Scattermapbox figure centred on the plotted markers.
        """
        if not locations:
            locations = [Location(
                name="Paris",
                lat=48.8566,
                lon=2.3522,
                relevance="Default location"
            )]
        fig = go.Figure(go.Scattermapbox(
            lat=[loc.lat for loc in locations],
            lon=[loc.lon for loc in locations],
            mode='markers',
            marker=go.scattermapbox.Marker(size=10),
            text=[loc.name for loc in locations]
        ))
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox=dict(
                # Centre on the markers themselves rather than a fixed point,
                # so results far from Paris are visible at zoom level 4.
                center=self._map_center(locations),
                zoom=4
            ),
            margin=dict(r=0, t=0, l=0, b=0)
        )
        return fig

    def get_llm_response(self, query: str, zoom_context: Optional[Dict[str, Any]] = None) -> ArtHistoryResponse:
        """Get a validated structured response from the LLM.

        Args:
            query: The user's free-text art-history query.
            zoom_context: Optional mapping from axis name ("temporal",
                "geographical", "style", "subject") to the current selection
                on that axis. Unknown axis names are ignored.

        Returns:
            The ``ArtHistoryResponse`` produced by the LLM, or
            ``get_default_response()`` if anything in the request fails.
        """
        try:
            print("\n=== Starting LLM Request ===")
            print(f"Input query: {query}")
            print(f"Zoom context: {zoom_context}")
            # All axes currently share the single global zoom counter.
            level = self.current_state["zoom_level"]
            current_zoom_states = {
                axis: {"level": level, "selection": ""}
                for axis in ("temporal", "geographical", "style", "subject")
            }
            for axis, selection in (zoom_context or {}).items():
                if axis in current_zoom_states:
                    current_zoom_states[axis]["selection"] = selection
            # default=str: a selection that is not JSON-native should still
            # produce a prompt instead of aborting the whole request.
            zoom_states_json = json.dumps(current_zoom_states, indent=2, default=str)
            messages = [
                {"role": "system", "content": "You are an expert art historian specializing in interactive exploration."},
                {"role": "user", "content": CONTEXTUAL_ZOOM_PROMPT.format(
                    user_query=query,
                    current_zoom_states=zoom_states_json
                )}
            ]
            print("\nSending request to LLM...")
            response = self.client.chat.completions.create(
                model="mixtral-8x7b-32768",
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
                response_model=ArtHistoryResponse
            )
            print("\nReceived validated response from LLM")
            return response
        except Exception as e:
            # Deliberate best-effort boundary: network, auth or validation
            # failures degrade to the canned default instead of breaking the UI.
            print(f"\nError in LLM response: {e}")
            print(f"Full error details: {e.__class__.__name__}")
            print(traceback.format_exc())
            return self.get_default_response()

    def get_default_response(self) -> ArtHistoryResponse:
        """Return default response when LLM fails"""
        return ArtHistoryResponse(**CONTEXTUAL_ZOOM_default_response)

    def create_interface(self) -> gr.Blocks:
        """Build the Gradio Blocks UI and wire up its event handlers.

        Returns:
            The (not yet launched) ``gr.Blocks`` application.
        """
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")
            with gr.Row():
                query = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy"
                )
                search_btn = gr.Button("Explore")
            with gr.Row():
                # Temporal axis
                with gr.Column():
                    # NOTE(review): gr.Slider holds a single number but is
                    # updated with ``temporal_config.range`` below — confirm
                    # that field is a scalar, not a (start, end) pair.
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("π Zoom Time Period")
                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("π Zoom Geography")
            with gr.Row():
                # Style axis
                style_select = gr.Dropdown(
                    choices=["Classical", "Modern"],
                    multiselect=True,
                    label="Artistic Styles",
                    allow_custom_value=True
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("π Zoom Styles")

            def initial_search(query: str) -> tuple:
                """Run a fresh search and populate every axis from the response."""
                # A new query starts a new exploration: forget the zoom depth
                # and remembered selections left over from the previous query.
                self.current_state["zoom_level"] = 0
                self.current_state["selections"].clear()
                response = self.get_llm_response(query)
                temporal_config = response.axis_configurations["temporal"].current_zoom
                geographical_config = response.axis_configurations["geographical"].current_zoom
                style_config = response.axis_configurations["style"].current_zoom
                self._remember_locations(geographical_config.locations)
                map_fig = self.create_map(geographical_config.locations)
                return (
                    temporal_config.range,
                    map_fig,
                    style_config.options,
                    temporal_config.explanation,
                    geographical_config.explanation,
                    style_config.explanation
                )

            def zoom_axis(query: str, axis_name: str, current_value: Any) -> tuple:
                """Re-query the LLM zoomed into *axis_name* at *current_value*.

                Returns a 2-tuple (new axis value, explanation) matching the
                outputs wired for that axis below.
                """
                self.current_state["zoom_level"] += 1
                response = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value}
                )
                axis_config = response.axis_configurations[axis_name].current_zoom
                if axis_name == "temporal":
                    return axis_config.range, axis_config.explanation
                if axis_name == "geographical":
                    self._remember_locations(axis_config.locations)
                    return self.create_map(axis_config.locations), axis_config.explanation
                # style
                return axis_config.options, axis_config.explanation

            def zoom_geography(q: str) -> tuple:
                """Zoom into the locations currently plotted on the map.

                A gr.Plot's value is the serialized figure, not a user
                selection, so the remembered location names are used instead.
                """
                plotted = list(self.current_state["selections"].get("geographical", []))
                return zoom_axis(q, "geographical", plotted)

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query],
                outputs=[
                    time_slider,
                    map_plot,
                    style_select,
                    time_explanation,
                    geo_explanation,
                    style_explanation
                ]
            )
            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query, time_slider],
                outputs=[time_slider, time_explanation]
            )
            geo_zoom.click(
                fn=zoom_geography,
                inputs=[query],
                outputs=[map_plot, geo_explanation]
            )
            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query, style_select],
                outputs=[style_select, style_explanation]
            )
        return demo
def main():
    """Build the Art History Explorer app and serve it."""
    print("Starting Art History Explorer...")
    app = ArtExplorer()
    print("Created ArtExplorer instance")
    interface = app.create_interface()
    print("Created interface")
    # Listen on all network interfaces on port 7860 and request a public
    # Gradio share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    interface.launch(**launch_options)


if __name__ == "__main__":
    main()