# Spaces:
# Sleeping
# Sleeping
import os | |
import openai | |
import gradio as gr | |
import json | |
import plotly.graph_objects as go | |
from variables import * | |
class ArtExplorer:
    """Interactive art-history explorer driven by an LLM served through Groq.

    A query (or a zoom on one axis) is sent to the model together with
    ``CONTEXTUAL_ZOOM_PROMPT``; the JSON reply describes three axes --
    temporal, geographical and style -- whose ``current_zoom`` blocks are
    pushed into the Gradio components built by ``create_interface``.

    NOTE(review): ``current_state`` lives on this one instance and is read and
    written by every event handler, so concurrent browser sessions served by
    the same process presumably share it; per-session ``gr.State`` would
    isolate them.
    """

    # Model id sent to the Groq chat-completions endpoint.
    # NOTE(review): hosted model ids get retired over time; confirm this one is still served.
    LLM_MODEL = "mixtral-8x7b-32768"
    # Fallback map centre and placeholder marker (Paris), as (lat, lon).
    _DEFAULT_CENTER = (48.8566, 2.3522)
    # Axes whose ``current_zoom`` block a fresh search reads.
    _PRIMARY_AXES = ("temporal", "geographical", "style")

    def __init__(self):
        # Groq exposes an OpenAI-compatible API, so the stock OpenAI client is
        # pointed at it; the key comes from the GROQ_API_KEY environment variable.
        self.client = openai.OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=os.environ.get("GROQ_API_KEY"),
        )
        # zoom_level: zoom steps taken since the last search (reported to the LLM).
        # selections: per-axis values remembered between events; currently the
        #     place names shown on the map, which the geography zoom sends back.
        self.current_state = {
            "zoom_level": 0,
            "selections": {},
        }

    def create_map(self, locations):
        """Create a Plotly map figure from location data.

        Args:
            locations: list of dicts with ``name``, ``lat`` and ``lon`` keys.
                An empty or None value falls back to a single Paris marker.

        Returns:
            A ``plotly.graph_objects.Figure`` with one Scattermapbox trace,
            centred on the mean of the locations' numeric coordinates.
        """
        if not locations:
            lat, lon = self._DEFAULT_CENTER
            locations = [{"name": "Paris", "lat": lat, "lon": lon}]
        fig = go.Figure(go.Scattermapbox(
            lat=[loc.get("lat") for loc in locations],
            lon=[loc.get("lon") for loc in locations],
            mode="markers",
            marker=go.scattermapbox.Marker(size=10),
            text=[loc.get("name") for loc in locations],
        ))
        # Centre on the plotted points instead of always on Paris.
        center_lat, center_lon = self._map_center(locations)
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox=dict(
                center=dict(lat=center_lat, lon=center_lon),
                zoom=4,
            ),
            margin=dict(r=0, t=0, l=0, b=0),
        )
        return fig

    @classmethod
    def _map_center(cls, locations):
        """Return the (lat, lon) mean of the locations that carry numeric coordinates.

        Entries whose ``lat``/``lon`` are missing or non-numeric are ignored;
        if none remain, ``_DEFAULT_CENTER`` is returned. A plain arithmetic
        mean is used, so points straddling the antimeridian are not handled.
        """
        def is_number(value):
            # bool is an int subclass but is never a coordinate.
            return isinstance(value, (int, float)) and not isinstance(value, bool)

        coords = [
            (loc.get("lat"), loc.get("lon"))
            for loc in locations or []
            if is_number(loc.get("lat")) and is_number(loc.get("lon"))
        ]
        if not coords:
            return cls._DEFAULT_CENTER
        mean_lat = sum(lat for lat, _ in coords) / len(coords)
        mean_lon = sum(lon for _, lon in coords) / len(coords)
        return (mean_lat, mean_lon)

    def get_llm_response(self, query: str, zoom_context: dict = None) -> dict:
        """Ask the LLM for an axis configuration matching ``query``.

        Args:
            query: free-text user query.
            zoom_context: optional mapping of axis name -> current selection.
                Known axes (temporal/geographical/style/subject) get the value
                as their ``selection``; values must be JSON-serialisable.

        Returns:
            The parsed JSON dict from the model, or ``get_default_response()``
            when the request fails, the reply is not JSON, or the reply lacks a
            ``current_zoom`` dict for an axis the caller will read (the zoomed
            axes, or all primary axes for a fresh search).
        """
        # A zoom only reads the zoomed axis; a fresh search reads every primary axis.
        required_axes = tuple(zoom_context) if zoom_context else self._PRIMARY_AXES
        content = None
        try:
            level = self.current_state["zoom_level"]
            current_zoom_states = {
                axis: {"level": level, "selection": ""}
                for axis in ("temporal", "geographical", "style", "subject")
            }
            for key, value in (zoom_context or {}).items():
                if key in current_zoom_states:
                    current_zoom_states[key]["selection"] = value
            messages = [
                {"role": "system", "content": "You are an expert art historian specializing in interactive exploration."},
                {"role": "user", "content": CONTEXTUAL_ZOOM_PROMPT.format(
                    user_query=query,
                    current_zoom_states=json.dumps(current_zoom_states, indent=2),
                )},
            ]
            # Print the messages being sent to the LLM for debugging
            print("Messages sent to LLM:")
            for message in messages:
                print(f"{message['role']}: {message['content']}")
            response = self.client.chat.completions.create(
                model=self.LLM_MODEL,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
            )
            # Print the raw response from the LLM for debugging
            print("Raw response from LLM:")
            print(response)
            content = response.choices[0].message.content
            result = self._parse_json_payload(content)
            # Print the parsed result for debugging
            print("Parsed result from LLM:")
            print(result)
        except json.JSONDecodeError as json_err:
            print(f"JSON decode error: {str(json_err)}")
            print(f"Response content: {content}")
            return self.get_default_response()
        except Exception as e:
            # Boundary for network/API/template errors: keep the UI alive.
            print(f"Error in LLM response: {str(e)}")
            return self.get_default_response()
        if not self._is_valid_config(result, required_axes):
            # Valid JSON but the wrong shape would otherwise raise KeyError in the UI handlers.
            print(f"LLM response lacks axis configurations for {required_axes}; using default response")
            return self.get_default_response()
        return result

    @staticmethod
    def _parse_json_payload(text):
        """Parse a model reply as JSON, tolerating surrounding prose or ``` fences.

        Raises:
            json.JSONDecodeError: if no JSON object can be recovered.
        """
        text = (text or "").strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            # Models often wrap the object in markdown fences or a sentence.
            return json.loads(text[start:end + 1])

    @staticmethod
    def _is_valid_config(config, axes):
        """Return True if ``config["axis_configurations"][axis]["current_zoom"]`` is a dict for every axis."""
        if not isinstance(config, dict):
            return False
        axis_configs = config.get("axis_configurations")
        if not isinstance(axis_configs, dict):
            return False
        return all(
            isinstance(axis_configs.get(axis), dict)
            and isinstance(axis_configs[axis].get("current_zoom"), dict)
            for axis in axes
        )

    def get_default_response(self):
        """Return the canned fallback configuration ``CONTEXTUAL_ZOOM_default_response``.

        The name presumably comes from the star import of ``variables``. The
        shared object itself is returned (not a copy), so callers must not mutate it.
        """
        return CONTEXTUAL_ZOOM_default_response

    def create_interface(self):
        """Build and return the Gradio Blocks app.

        Layout: a query box with an "Explore" button, a time slider and a map
        side by side, then a style multi-select. Each axis has a zoom button
        that re-queries the LLM with that axis' current selection.
        """
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")
            with gr.Row():
                query = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy",
                )
                search_btn = gr.Button("Explore")
            # NOTE(review): "π" in the zoom labels looks like a mis-encoded emoji; confirm the intended glyph.
            with gr.Row():
                # Temporal axis
                with gr.Column():
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True,
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("π Zoom Time Period")
                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("π Zoom Geography")
            with gr.Row():
                style_select = gr.Dropdown(
                    multiselect=True,
                    label="Artistic Styles",
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("π Zoom Styles")

            def remember_locations(locations):
                # Keep the place names on the map; the geography zoom sends them to the LLM.
                self.current_state["selections"]["geographical"] = [
                    loc.get("name") for loc in locations or []
                ]

            def initial_search(query):
                # A new query starts a fresh exploration: drop earlier zoom state.
                self.current_state["zoom_level"] = 0
                self.current_state["selections"] = {}
                config = self.get_llm_response(query)
                axes = config["axis_configurations"]
                temporal_config = axes["temporal"]["current_zoom"]
                geographical_config = axes["geographical"]["current_zoom"]
                style_config = axes["style"]["current_zoom"]
                locations = geographical_config.get("locations", [])
                remember_locations(locations)
                # NOTE(review): gr.Slider holds one number; confirm "range" is a scalar.
                return {
                    time_slider: temporal_config["range"],
                    map_plot: self.create_map(locations),
                    style_select: gr.Dropdown(choices=style_config["options"]),
                    time_explanation: temporal_config["explanation"],
                    geo_explanation: geographical_config.get("explanation", ""),
                    style_explanation: style_config["explanation"],
                }

            def zoom_axis(query, axis_name, current_value):
                self.current_state["zoom_level"] += 1
                config = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value},
                )
                axis_config = config["axis_configurations"][axis_name]["current_zoom"]
                if axis_name == "temporal":
                    return {
                        time_slider: axis_config["range"],
                        time_explanation: axis_config["explanation"],
                    }
                if axis_name == "geographical":
                    locations = axis_config.get("locations", [])
                    remember_locations(locations)
                    return {
                        map_plot: self.create_map(locations),
                        geo_explanation: axis_config.get("explanation", ""),
                    }
                # style
                return {
                    style_select: gr.Dropdown(choices=axis_config["options"]),
                    style_explanation: axis_config["explanation"],
                }

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query],
                outputs=[
                    time_slider,
                    map_plot,
                    style_select,
                    time_explanation,
                    geo_explanation,
                    style_explanation,
                ],
            )
            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query, time_slider],
                outputs=[time_slider, time_explanation],
            )
            geo_zoom.click(
                # The map component's value is a figure payload, not a location
                # selection (depending on the Gradio version it is either not
                # JSON-serialisable or a bulky figure dump), so send the place
                # names remembered from the last rendered map instead.
                fn=lambda q: zoom_axis(
                    q,
                    "geographical",
                    self.current_state["selections"].get("geographical", []),
                ),
                inputs=[query],
                outputs=[map_plot, geo_explanation],
            )
            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query, style_select],
                outputs=[style_select, style_explanation],
            )
        return demo
if __name__ == "__main__":
    # Script entry point: construct the explorer (which reads GROQ_API_KEY),
    # build the Gradio Blocks app from it, and start serving it.
    # The app stays bound to `demo`, the name Gradio tooling conventionally looks for.
    explorer = ArtExplorer()
    demo = explorer.create_interface()
    demo.launch()