Encyclopedia / app.py
baconnier's picture
Update app.py
9bf74ce verified
raw
history blame
9.2 kB
import os
import openai
import gradio as gr
import json
import plotly.graph_objects as go
from variables import CONTEXTUAL_ZOOM_PROMPT, CONTEXTUAL_ZOOM_default_response
class ArtExplorer:
    """Gradio front end that asks an LLM how to configure art-history axes.

    The instance owns an OpenAI-compatible client pointed at the Groq endpoint
    and a small ``current_state`` dict (zoom depth and selections) that the UI
    event handlers read and update.

    NOTE(review): ``current_state`` lives on the instance, and ``main`` builds a
    single instance for the whole app, so the zoom depth is shared by every
    browser session — confirm that is acceptable.
    """

    # Marker shown (and map centre used) when no plottable location is given.
    DEFAULT_LOCATION = {"name": "Paris", "lat": 48.8566, "lon": 2.3522}

    def __init__(self):
        self.client = openai.OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=os.environ.get("GROQ_API_KEY")
        )
        self.current_state = {
            "zoom_level": 0,
            "selections": {}
        }

    @staticmethod
    def _coordinate(value, limit):
        """Return ``value`` as a float within ``[-limit, limit]``, else ``None``.

        LLM output may carry coordinates as strings, ``None`` or out-of-range
        numbers; anything that is not a usable coordinate yields ``None``.
        """
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        # The chained comparison is False for NaN as well, which rejects it.
        return number if -limit <= number <= limit else None

    def create_map(self, locations):
        """Create a Plotly map figure from location data.

        Args:
            locations: List of dicts with ``name``, ``lat`` and ``lon`` keys.
                Entries without usable coordinates are skipped. When nothing
                plottable remains (including ``None`` or an empty list) a
                single default marker on Paris is shown.

        Returns:
            A Plotly figure with one marker per plotted location, centred on
            the mean position of those markers.
        """
        points = []
        if isinstance(locations, (list, tuple)):
            for loc in locations:
                if not isinstance(loc, dict):
                    continue
                lat = self._coordinate(loc.get('lat'), 90)
                lon = self._coordinate(loc.get('lon'), 180)
                if lat is not None and lon is not None:
                    points.append((lat, lon, loc.get('name')))

        if not points:
            fallback = self.DEFAULT_LOCATION
            points = [(fallback["lat"], fallback["lon"], fallback["name"])]

        lats = [point[0] for point in points]
        lons = [point[1] for point in points]
        names = [point[2] for point in points]

        fig = go.Figure(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='markers',
            marker=go.scattermapbox.Marker(size=10),
            text=names
        ))
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox=dict(
                # Centre on what is actually plotted instead of a fixed city,
                # so markers far from Paris are not off-screen.
                center=dict(lat=sum(lats) / len(lats), lon=sum(lons) / len(lons)),
                zoom=4
            ),
            margin=dict(r=0, t=0, l=0, b=0)
        )
        return fig

    def clean_json_string(self, json_str: str) -> str:
        """Trim an LLM reply down to its outermost ``{...}`` span.

        Leading/trailing whitespace, a leading BOM, and any text before the
        first ``{`` or after the last ``}`` (for example Markdown code fences)
        are removed. A falsy input such as ``None`` or ``""`` yields ``""``.
        The result is not guaranteed to be valid JSON; the caller parses it.
        """
        if not json_str:
            return ""

        json_str = json_str.strip()

        # Remove a BOM at the start
        if json_str.startswith('\ufeff'):
            json_str = json_str[1:]

        # Drop anything before the first opening brace
        start_idx = json_str.find('{')
        if start_idx > 0:
            json_str = json_str[start_idx:]

        # Drop anything after the last closing brace
        end_idx = json_str.rfind('}')
        if end_idx != -1:
            json_str = json_str[:end_idx + 1]

        return json_str

    def get_llm_response(self, query: str, zoom_context: dict = None) -> dict:
        """Ask the LLM for axis configurations matching ``query``.

        Args:
            query: The user's free-text query.
            zoom_context: Optional mapping of axis name (``temporal``,
                ``geographical``, ``style``, ``subject``) to the value currently
                selected on that axis. Unknown axis names are ignored.

        Returns:
            The parsed JSON object from the model, or the default response when
            the request fails, the reply is not parseable JSON, or the parsed
            value is not a JSON object. This method does not raise for those
            failures; it logs them to stdout.
        """
        try:
            print("\n=== Starting LLM Request ===")
            print(f"Input query: {query}")
            print(f"Zoom context: {zoom_context}")

            zoom_level = self.current_state["zoom_level"]
            current_zoom_states = {
                axis: {"level": zoom_level, "selection": ""}
                for axis in ("temporal", "geographical", "style", "subject")
            }

            if zoom_context:
                for key, value in zoom_context.items():
                    if key in current_zoom_states:
                        current_zoom_states[key]["selection"] = value

            messages = [
                {"role": "system", "content": "You are an expert art historian specializing in interactive exploration."},
                {"role": "user", "content": CONTEXTUAL_ZOOM_PROMPT.format(
                    user_query=query,
                    current_zoom_states=json.dumps(current_zoom_states, indent=2)
                )}
            ]

            print("\nPrepared messages for LLM:")
            print(json.dumps(messages, indent=2))

            print("\nSending request to LLM...")
            # NOTE(review): the model id is hard-coded; verify it is still
            # served by the configured endpoint.
            response = self.client.chat.completions.create(
                model="mixtral-8x7b-32768",
                messages=messages,
                temperature=0.1,
                max_tokens=2048
            )
            print("\nReceived raw response from LLM:")
            print(response)

            # Get the response content and clean it
            content = response.choices[0].message.content
            print("\nExtracted content from response:")
            print(content)

            # Clean the JSON string
            cleaned_content = self.clean_json_string(content)
            print("\nCleaned JSON string:")
            print(cleaned_content)

            try:
                result = json.loads(cleaned_content)
            except json.JSONDecodeError as e:
                print(f"\nJSON parsing error: {str(e)}")
                print(f"Failed content: {cleaned_content}")
                return self.get_default_response()

            if not isinstance(result, dict):
                # The handlers index the result by key, so anything other than
                # a JSON object is unusable.
                print(f"\nUnexpected JSON type: {type(result).__name__}")
                return self.get_default_response()

            print("\nSuccessfully parsed JSON:")
            print(json.dumps(result, indent=2))
            return result

        except Exception as e:
            # Boundary for the UI: any failure (network, auth, prompt
            # formatting, serialisation) degrades to the default response.
            print(f"\nError in LLM response: {str(e)}")
            print(f"Full error details: {e.__class__.__name__}")
            import traceback
            print(traceback.format_exc())
            return self.get_default_response()

    def get_default_response(self):
        """Return the canned response used when the LLM call or parsing fails."""
        return CONTEXTUAL_ZOOM_default_response

    def _current_zoom_config(self, config, axis_name):
        """Return ``config["axis_configurations"][axis_name]["current_zoom"]``.

        The model can return valid JSON that lacks the expected nesting; in that
        case the same path is read from the default response instead of letting
        a ``KeyError`` escape into the UI handler.
        """
        try:
            zoom = config["axis_configurations"][axis_name]["current_zoom"]
        except (KeyError, TypeError, IndexError):
            zoom = None
        if isinstance(zoom, dict):
            return zoom
        print(f"\nMissing '{axis_name}' zoom configuration; using default response")
        return self.get_default_response()["axis_configurations"][axis_name]["current_zoom"]

    def create_interface(self):
        """Build the Gradio Blocks UI and wire its event handlers.

        Returns:
            The ``gr.Blocks`` app, ready for ``launch()``.
        """
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")

            with gr.Row():
                query_box = gr.Textbox(
                    label="Enter your art history query",
                    placeholder="e.g., Napoleon wars, Renaissance Italy"
                )
                search_btn = gr.Button("Explore")

            with gr.Row():
                # Temporal axis
                with gr.Column():
                    time_slider = gr.Slider(
                        minimum=1000,
                        maximum=2024,
                        label="Time Period",
                        interactive=True
                    )
                    time_explanation = gr.Markdown()
                    time_zoom = gr.Button("🔍 Zoom Time Period")

                # Geographical axis
                with gr.Column():
                    map_plot = gr.Plot(label="Geographic Location")
                    geo_explanation = gr.Markdown()
                    geo_zoom = gr.Button("🔍 Zoom Geography")

            with gr.Row():
                style_select = gr.Dropdown(
                    choices=["Classical", "Modern"],
                    multiselect=True,
                    label="Artistic Styles",
                    allow_custom_value=True
                )
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("🔍 Zoom Styles")

            def initial_search(query):
                """Handle the initial search query.

                Returns the new values for, in order: time slider, map plot,
                style dropdown, and the three explanation Markdown areas.
                """
                # A fresh search starts a new exploration, so discard the zoom
                # depth accumulated by earlier zoom clicks.
                self.current_state["zoom_level"] = 0

                config = self.get_llm_response(query)
                temporal_config = self._current_zoom_config(config, "temporal")
                geographical_config = self._current_zoom_config(config, "geographical")
                style_config = self._current_zoom_config(config, "style")

                map_fig = self.create_map(geographical_config.get("locations"))

                # NOTE(review): "range" is forwarded as the Slider value; a
                # gr.Slider presumably expects a single number — confirm the
                # shape the prompt asks the model to produce.
                return (
                    temporal_config["range"],
                    map_fig,
                    style_config["options"],
                    temporal_config.get("explanation", ""),
                    geographical_config.get("explanation", ""),
                    style_config.get("explanation", "")
                )

            def zoom_axis(query, axis_name, current_value):
                """Handle zoom events for any axis.

                Increments the shared zoom level, re-queries the LLM with the
                axis's current value as context, and returns a pair of
                (new component value, explanation text) for that axis.
                """
                self.current_state["zoom_level"] += 1

                config = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value}
                )
                axis_config = self._current_zoom_config(config, axis_name)
                explanation = axis_config.get("explanation", "")

                if axis_name == "temporal":
                    return axis_config["range"], explanation
                if axis_name == "geographical":
                    return self.create_map(axis_config.get("locations")), explanation
                # style
                return axis_config["options"], explanation

            # Connect event handlers
            search_btn.click(
                fn=initial_search,
                inputs=[query_box],
                outputs=[
                    time_slider,
                    map_plot,
                    style_select,
                    time_explanation,
                    geo_explanation,
                    style_explanation
                ]
            )

            time_zoom.click(
                fn=lambda q, v: zoom_axis(q, "temporal", v),
                inputs=[query_box, time_slider],
                outputs=[time_slider, time_explanation]
            )

            # NOTE(review): the value Gradio passes for a gr.Plot input may not
            # be JSON-serialisable; if so, get_llm_response fails while building
            # the prompt and returns the default response — confirm and consider
            # passing a plain-text selection instead.
            geo_zoom.click(
                fn=lambda q, v: zoom_axis(q, "geographical", v),
                inputs=[query_box, map_plot],
                outputs=[map_plot, geo_explanation]
            )

            style_zoom.click(
                fn=lambda q, v: zoom_axis(q, "style", v),
                inputs=[query_box, style_select],
                outputs=[style_select, style_explanation]
            )

        return demo
def main():
    """Build the explorer and its UI, then serve it.

    The app is launched on all interfaces at port 7860 with a public share
    link requested. Progress is reported on stdout at each step.
    """
    print("Starting initialization...")
    art_explorer = ArtExplorer()
    print("Created ArtExplorer instance")

    app = art_explorer.create_interface()
    print("Created interface")

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()