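# Page-header residue from the Hugging Face Spaces file viewer (not part of the
# program): the Space status shown at capture time was "Runtime error".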
import os | |
from queue import Queue | |
import json | |
import gradio as gr | |
import argilla as rg | |
from argilla.webhooks import webhook_listener | |
# Argilla client. Both connection settings come from the environment; an unset
# variable yields None and is handed to the client unchanged.
client = rg.Argilla(
    api_key=os.environ.get("ARGILLA_API_KEY"),
    api_url=os.environ.get("ARGILLA_API_URL"),
)

# Server object that receives the webhook calls; the Gradio UI is mounted on
# it further down in this file.
server = rg.get_webhook_server()

# Events produced by the webhook handler and consumed by the UI.
incoming_events = Queue()
# Annotation progress per country, keyed by two-letter ISO country code.
# Each value has the shape:
#   {"count": 0, "percent": 0, "name": "Country Name", "target": 0}
annotation_progress = {}

# Tracked countries: ISO code -> display name and target number of annotations.
COUNTRY_MAPPING = {
    iso_code: {"name": country_name, "target": target_count}
    for iso_code, country_name, target_count in (
        ("MX", "Mexico", 1000),
        ("AR", "Argentina", 800),
        ("CO", "Colombia", 700),
        ("CL", "Chile", 600),
        ("PE", "Peru", 600),
        ("ES", "Spain", 1200),
        ("BR", "Brazil", 1000),
        ("VE", "Venezuela", 500),
        ("EC", "Ecuador", 400),
        ("BO", "Bolivia", 300),
        ("PY", "Paraguay", 300),
        ("UY", "Uruguay", 300),
        ("CR", "Costa Rica", 250),
        ("PA", "Panama", 250),
        ("DO", "Dominican Republic", 300),
        ("GT", "Guatemala", 250),
        ("HN", "Honduras", 200),
        ("SV", "El Salvador", 200),
        ("NI", "Nicaragua", 200),
        ("CU", "Cuba", 300),
    )
}

# Seed every tracked country with zero progress; name and target are copied
# from the mapping entry.
for iso_code, info in COUNTRY_MAPPING.items():
    annotation_progress[iso_code] = {"count": 0, "percent": 0, **info}
# Set up the webhook listener for response creation.
# Without this decorator the handler is defined but never registered, so no
# "response.created" event would ever reach it.
@webhook_listener(events=["response.created"])
async def update_validation_space_on_answer(response, type, timestamp):
    """
    Copy a newly created answer into the country's validation dataset and
    update that country's annotation progress.

    Only responses whose record belongs to a dataset named
    ``<country>_responder_preguntas`` are processed; everything else is
    logged and ignored. The answer is logged as a new record in
    ``<country>_validar_respuestas``, which is created through
    ``build_space.create_validation_space`` when it cannot be found.

    Args:
        response: The created response. This handler reads ``response.record``,
            ``response.user_id`` and ``response.to_dict()``.
        type: Event type, stored as-is in the UI event queue. The name shadows
            the builtin but is kept because renaming it would change the
            handler's interface.
        timestamp: Event timestamp, stored as ``str(timestamp)``.

    Returns:
        None. Any failure is caught, printed, and pushed to
        ``incoming_events`` as an ``{"event": "error", ...}`` entry rather
        than raised.
    """
    answering_suffix = "_responder_preguntas"
    try:
        # Store the event for display in the UI
        incoming_events.put({"event": type, "timestamp": str(timestamp)})

        # Get the record from the response
        record = response.record

        # Check if this is from an answering space
        dataset_name = record.dataset.name
        if not dataset_name.endswith(answering_suffix):
            print(f"Ignoring event from non-answering dataset: {dataset_name}")
            return  # Not an answering space, ignore

        # Strip only the trailing suffix; str.replace would also remove the
        # marker if it happened to occur earlier in the name.
        country = dataset_name[: -len(answering_suffix)]
        print(f"Processing response for country: {country}")

        # Connect to the validation space
        validation_dataset_name = f"{country}_validar_respuestas"
        validation_dataset = None
        try:
            validation_dataset = client.datasets(validation_dataset_name)
        except Exception as e:
            print(f"Error connecting to validation dataset: {e}")

        if validation_dataset is None:
            # The lookup can come back empty instead of raising for an unknown
            # dataset, so "not found" has to be checked explicitly.
            print(f"Validation dataset not available, creating: {validation_dataset_name}")
            # Imported lazily: only needed on this fallback path.
            from build_space import create_validation_space

            validation_dataset = create_validation_space(country)
        else:
            print(f"Found validation dataset: {validation_dataset_name}")

        response_dict = response.to_dict()
        answer = response_dict["values"]["text"]["value"]

        # Get the user ID of the original responder
        original_user_id = str(response.user_id)

        # Create a validation record with the correct attribute
        validation_record = {
            "question": record.fields["question"],
            "answer": answer,
            "metadata": {
                "original_responder_id": original_user_id,
                "original_dataset": dataset_name,
            },
        }

        # Add the record to the validation space
        validation_dataset.records.log(records=[validation_record])
        print(f"Added new response to validation space for {country}")

        # Resolve the ISO code for the dataset prefix. Separators are
        # normalised so multi-word names ("costa_rica") match "Costa Rica",
        # and a prefix that is itself an ISO code ("mx") is accepted too.
        normalized_country = country.lower().replace("_", " ").replace("-", " ")
        country_code = next(
            (
                code
                for code, data in COUNTRY_MAPPING.items()
                if code.lower() == normalized_country
                or data["name"].lower() in normalized_country
            ),
            None,
        )

        progress = annotation_progress.get(country_code) if country_code else None
        if progress is None:
            print(f"No tracked country matches '{country}'; progress not updated")
            return

        # Increment the count and recompute the capped percentage
        progress["count"] += 1
        count = progress["count"]
        target = progress["target"]
        percent = min(100, int((count / target) * 100)) if target else 0
        progress["percent"] = percent

        # Update event queue with progress information
        incoming_events.put({
            "event": "progress_update",
            "country": progress["name"],
            "count": count,
            "percent": percent,
        })
    except Exception as e:
        print(f"Error in webhook handler: {e}")
        # Store the error in the queue for display
        incoming_events.put({"event": "error", "error": str(e)})
# Function to read the next event from the queue and update UI elements
def read_next_event():
    """Pop and return the oldest pending event without ever blocking.

    Returns:
        The next item from ``incoming_events``, or an empty dict when the
        queue has nothing waiting.
    """
    # Local import: the file's import block only brings in Queue.
    from queue import Empty

    try:
        # get_nowait() tests and removes in one step. A separate empty()
        # check followed by a blocking get() can hang when another caller
        # drains the queue in between.
        return incoming_events.get_nowait()
    except Empty:
        return {}
# Function to get the current annotation progress data
def get_annotation_progress():
    """Return the module-level ``annotation_progress`` dict as a JSON string."""
    serialized = json.dumps(annotation_progress)
    return serialized
# D3.js map visualization HTML template
def create_map_html(progress_data):
    """Build the HTML document that draws the annotation-progress map with D3.

    The page loads a world GeoJSON file, draws every country, and fills the
    tracked ones on a colour scale driven by their ``percent`` value. Hovering
    a tracked country shows its name, count/target and completion.

    A feature is matched to a progress entry by ``properties.iso_a2`` first
    and, when that yields nothing, by ``properties.name`` against the entry's
    ``name``. The fallback exists because the GeoJSON loaded here may not
    carry ``iso_a2`` at all (NOTE(review): believed to expose only a name and
    an id -- confirm against the file), in which case no country would ever
    be coloured.

    Args:
        progress_data: JSON text embedded verbatim as the JavaScript object
            ``progressData``. Callers in this file pass
            ``json.dumps(annotation_progress)``: an object keyed by ISO code
            whose values have ``name``, ``count``, ``target`` and ``percent``.
            It is not escaped, so it must come from a trusted source.

    Returns:
        The complete HTML document as a string.
    """
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
    <script src="https://d3js.org/d3.v7.min.js"></script>
    <script src="https://d3js.org/d3-geo.v2.min.js"></script>
    <script src="https://d3js.org/topojson.v3.min.js"></script>
    <style>
    .country {{
        stroke: #f32b7b;
        stroke-width: 1px;
    }}
    .country:hover {{
        stroke: #4a1942;
        stroke-width: 2px;
        cursor: pointer;
    }}
    .tooltip {{
        position: absolute;
        background-color: rgba(0, 0, 0, 0.8);
        border-radius: 5px;
        padding: 8px;
        color: white;
        font-size: 12px;
        pointer-events: none;
        opacity: 0;
        transition: opacity 0.3s;
    }}
    text {{
        fill: #f1f5f9;
        font-family: sans-serif;
    }}
    .legend-text {{
        fill: #94a3b8;
        font-size: 10px;
    }}
    </style>
    </head>
    <body style="margin:0; background-color:#111;">
    <div id="map-container" style="width:100%; height:600px; position:relative;"></div>
    <div id="tooltip" class="tooltip"></div>
    <script>
    // The progress data passed from Python
    const progressData = {progress_data};
    // Secondary index by country name, used when a feature has no ISO code
    const progressByName = {{}};
    Object.values(progressData).forEach(entry => {{
        if (entry && entry.name) {{
            progressByName[entry.name] = entry;
        }}
    }});
    // Resolve the progress entry for a GeoJSON feature: ISO code first, then name
    function progressFor(feature) {{
        const props = feature.properties || {{}};
        return progressData[props.iso_a2] || progressByName[props.name];
    }}
    // Set up the SVG container
    const width = document.getElementById('map-container').clientWidth;
    const height = 600;
    const svg = d3.select("#map-container")
        .append("svg")
        .attr("width", width)
        .attr("height", height)
        .attr("viewBox", `0 0 ${{width}} ${{height}}`)
        .style("background-color", "#111");
    // Define color scale
    const colorScale = d3.scaleLinear()
        .domain([0, 100])
        .range(["#4a1942", "#f32b7b"]);
    // Set up projection focused on Latin America and Spain
    const projection = d3.geoMercator()
        .center([-60, 0])
        .scale(width / 5)
        .translate([width / 2, height / 2]);
    const path = d3.geoPath().projection(projection);
    // Tooltip setup
    const tooltip = d3.select("#tooltip");
    // Load the world GeoJSON data
    d3.json("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson")
        .then(data => {{
            // Draw the map
            svg.selectAll("path")
                .data(data.features)
                .enter()
                .append("path")
                .attr("d", path)
                .attr("class", "country")
                .attr("fill", d => {{
                    const entry = progressFor(d);
                    if (entry) {{
                        return colorScale(entry.percent);
                    }}
                    return "#2d3748"; // Default gray for non-tracked countries
                }})
                .on("mouseover", function(event, d) {{
                    const entry = progressFor(d);
                    if (entry) {{
                        tooltip.style("opacity", 1)
                            .html(`
                                <strong>${{entry.name}}</strong><br/>
                                Documents: ${{entry.count.toLocaleString()}}/${{entry.target.toLocaleString()}}<br/>
                                Completion: ${{entry.percent}}%
                            `);
                    }}
                }})
                .on("mousemove", function(event) {{
                    tooltip.style("left", (event.pageX + 15) + "px")
                        .style("top", (event.pageY + 15) + "px");
                }})
                .on("mouseout", function() {{
                    tooltip.style("opacity", 0);
                }});
            // Add legend
            const legendWidth = Math.min(width - 40, 200);
            const legendHeight = 15;
            const legendX = width - legendWidth - 20;
            const legend = svg.append("g")
                .attr("class", "legend")
                .attr("transform", `translate(${{legendX}}, 30)`);
            // Create gradient for legend
            const defs = svg.append("defs");
            const gradient = defs.append("linearGradient")
                .attr("id", "dataGradient")
                .attr("x1", "0%")
                .attr("y1", "0%")
                .attr("x2", "100%")
                .attr("y2", "0%");
            gradient.append("stop")
                .attr("offset", "0%")
                .attr("stop-color", "#4a1942");
            gradient.append("stop")
                .attr("offset", "100%")
                .attr("stop-color", "#f32b7b");
            // Add legend title
            legend.append("text")
                .attr("x", legendWidth / 2)
                .attr("y", -10)
                .attr("text-anchor", "middle")
                .attr("font-size", "12px")
                .text("Annotation Progress");
            // Add legend rectangle
            legend.append("rect")
                .attr("width", legendWidth)
                .attr("height", legendHeight)
                .attr("rx", 2)
                .attr("ry", 2)
                .style("fill", "url(#dataGradient)");
            // Add legend labels
            legend.append("text")
                .attr("class", "legend-text")
                .attr("x", 0)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "start")
                .text("0%");
            legend.append("text")
                .attr("class", "legend-text")
                .attr("x", legendWidth / 2)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "middle")
                .text("50%");
            legend.append("text")
                .attr("class", "legend-text")
                .attr("x", legendWidth)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "end")
                .text("100%");
        }});
    // Handle window resize
    window.addEventListener('resize', () => {{
        const width = document.getElementById('map-container').clientWidth;
        // Update SVG dimensions
        d3.select("svg")
            .attr("width", width)
            .attr("viewBox", `0 0 ${{width}} ${{height}}`);
        // Update projection
        projection.scale(width / 5)
            .translate([width / 2, height / 2]);
        // Update paths
        d3.selectAll("path").attr("d", path);
        // Update legend position
        const legendWidth = Math.min(width - 40, 200);
        const legendX = width - legendWidth - 20;
        d3.select(".legend")
            .attr("transform", `translate(${{legendX}}, 30)`);
    }});
    </script>
    </body>
    </html>
    """
def _render_map_frame():
    """Return the current progress map wrapped in an ``<iframe srcdoc>``.

    ``create_map_html`` produces a full HTML document whose drawing is done
    by ``<script>`` tags. Markup placed directly into an HTML component is
    inserted without running its scripts, so the document is embedded in an
    iframe, where it loads as a page of its own.
    """
    # Local import so the file's import block does not need to change.
    from html import escape

    document = create_map_html(json.dumps(annotation_progress))
    return (
        f'<iframe srcdoc="{escape(document, quote=True)}" '
        'title="Annotation Progress Map" '
        'style="width:100%; height:620px; border:0;"></iframe>'
    )


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple")) as demo:
    argilla_server = client.http_client.base_url if hasattr(client, 'http_client') else "Not connected"

    with gr.Row():
        # The Markdown body stays unindented so no line is read as a code block.
        gr.Markdown(f"""
# Argilla Annotation Progress Map
### Connected to Argilla server: {argilla_server}
This dashboard visualizes annotation progress across Latin America and Spain.
The webhook listens for `response.created` events from datasets ending with `_responder_preguntas`.
""")

    with gr.Row():
        with gr.Column(scale=2):
            # Map visualization
            map_html = gr.HTML(_render_map_frame(), label="Annotation Progress Map")
            # Stats section
            with gr.Accordion("Overall Statistics", open=False):
                total_docs = gr.Number(value=0, label="Total Documents Collected")
                avg_completion = gr.Number(value=0, label="Average Completion (%)")
                countries_over_50 = gr.Number(value=0, label="Countries Over 50% Complete")
        with gr.Column(scale=1):
            # Recent events log
            events_json = gr.JSON(label="Recent Events", value={})
            # Country details
            with gr.Accordion("Country Details", open=True):
                country_selector = gr.Dropdown(
                    choices=[f"{data['name']} ({code})" for code, data in COUNTRY_MAPPING.items()],
                    label="Select Country"
                )
                country_progress = gr.JSON(label="Country Progress", value={})

    # Functions to update the UI
    def update_map():
        """Re-render the map from the current ``annotation_progress``."""
        return _render_map_frame()

    def update_stats():
        """Return (total count, mean percent, number of countries at >= 50%)."""
        total = sum(data["count"] for data in annotation_progress.values())
        percentages = [data["percent"] for data in annotation_progress.values()]
        avg = sum(percentages) / len(percentages) if percentages else 0
        countries_50_plus = sum(1 for p in percentages if p >= 50)
        return total, avg, countries_50_plus

    def update_country_details(country_selection):
        """Return the progress entry for a dropdown value, or {} if unknown."""
        if not country_selection:
            return {}
        # Extract the country code from the selection (format: "Country Name (CODE)")
        code = country_selection.split("(")[-1].replace(")", "").strip()
        if code in annotation_progress:
            return annotation_progress[code]
        return {}

    # Simple event processing functions for better compatibility
    def update_events():
        """Return the next queued event for the "Recent Events" panel."""
        # NOTE(review): read_next_event() yields {} when nothing is pending,
        # so the panel is reset on the next tick -- confirm this is intended.
        return read_next_event()

    def update_all_stats():
        """Timer callback that forwards to ``update_stats``."""
        return update_stats()

    # Use separate timers for each component for better compatibility
    gr.Timer(1, active=True).tick(update_events, outputs=events_json)
    gr.Timer(5, active=True).tick(update_all_stats, outputs=[total_docs, avg_completion, countries_over_50])

    country_selector.change(
        fn=update_country_details,
        inputs=[country_selector],
        outputs=[country_progress]
    )

    # Use refresh button instead of timer for map updates (more compatible across versions)
    refresh_btn = gr.Button("Refresh Map")
    refresh_btn.click(
        fn=update_map,
        inputs=None,
        outputs=[map_html]
    )
# Serve the Gradio UI from the root path of the webhook server.
gr.mount_gradio_app(server, demo, path="/")

# Script entry point: seed sample progress, then start the server.
if __name__ == "__main__":
    import uvicorn

    # Sample data shown at startup: ISO code -> (fraction of target used for
    # the count, percent value to display).
    sample_progress = {
        "MX": (0.3, 30),
        "AR": (0.3, 30),
        "CO": (0.3, 30),
        "ES": (0.3, 30),
        "BR": (0.5, 50),
        "CL": (0.7, 70),
    }
    for code, (fraction, percent) in sample_progress.items():
        entry = annotation_progress[code]
        entry["count"] = int(entry["target"] * fraction)
        entry["percent"] = percent

    # Listen on all interfaces, port 7860.
    uvicorn.run(server, host="0.0.0.0", port=7860)