ouhenio's picture
Update app.py
8eec983 verified
raw
history blame
18.8 kB
import os
from queue import Queue
import json
import gradio as gr
import argilla as rg
from argilla.webhooks import webhook_listener
# Argilla client; the server URL and API key are taken from the environment
# (each is None when its variable is unset).
client = rg.Argilla(
    api_url=os.environ.get("ARGILLA_API_URL"),
    api_key=os.environ.get("ARGILLA_API_KEY"),
)

# Webhook server object provided by Argilla; the Gradio UI is mounted on it
# and it is what uvicorn serves at the bottom of this file.
server = rg.get_webhook_server()

# Events produced by the webhook handler, waiting to be shown in the UI.
incoming_events = Queue()
# Tracked countries, keyed by two-letter ISO country code. Each entry carries
# the display name and the target number of annotations for that country.
COUNTRY_MAPPING = {
    "MX": {"name": "Mexico", "target": 1000},
    "AR": {"name": "Argentina", "target": 800},
    "CO": {"name": "Colombia", "target": 700},
    "CL": {"name": "Chile", "target": 600},
    "PE": {"name": "Peru", "target": 600},
    "ES": {"name": "Spain", "target": 1200},
    "BR": {"name": "Brazil", "target": 1000},
    "VE": {"name": "Venezuela", "target": 500},
    "EC": {"name": "Ecuador", "target": 400},
    "BO": {"name": "Bolivia", "target": 300},
    "PY": {"name": "Paraguay", "target": 300},
    "UY": {"name": "Uruguay", "target": 300},
    "CR": {"name": "Costa Rica", "target": 250},
    "PA": {"name": "Panama", "target": 250},
    "DO": {"name": "Dominican Republic", "target": 300},
    "GT": {"name": "Guatemala", "target": 250},
    "HN": {"name": "Honduras", "target": 200},
    "SV": {"name": "El Salvador", "target": 200},
    "NI": {"name": "Nicaragua", "target": 200},
    "CU": {"name": "Cuba", "target": 300}
}

# Annotation progress by country code. Every tracked country starts at zero:
#   "country_code": {"count": 0, "percent": 0, "name": ..., "target": ...}
annotation_progress = {
    iso_code: {
        "count": 0,
        "percent": 0,
        "name": info["name"],
        "target": info["target"],
    }
    for iso_code, info in COUNTRY_MAPPING.items()
}
def _find_country_code(country):
    """Return the COUNTRY_MAPPING code whose country name occurs in *country*, or None."""
    # Dataset prefixes may separate words with "_" or "-" (e.g. "costa_rica"),
    # while COUNTRY_MAPPING names use spaces; normalize before the substring test.
    normalized = country.lower().replace("_", " ").replace("-", " ")
    for code, data in COUNTRY_MAPPING.items():
        if data["name"].lower() in normalized:
            return code
    return None


# Set up the webhook listener for response creation
@webhook_listener(events=["response.created"])
async def update_validation_space_on_answer(response, type, timestamp):
    """
    Webhook listener that triggers when a new response is added to an answering space.

    It copies the answer into the matching validation dataset and updates the
    country's annotation progress. Every outcome (event, progress update or
    error) is pushed onto ``incoming_events`` for display in the UI.

    Args:
        response: The created response. The code reads ``response.record``,
            ``response.user_id`` and ``response.to_dict()`` from it.
        type: Event type as delivered by the webhook (the name shadows the
            builtin but is kept so the listener signature stays unchanged).
        timestamp: Event timestamp; stored as ``str(timestamp)``.

    Returns:
        None. Exceptions are caught, printed and queued as an ``"error"`` event.

    NOTE(review): the Argilla client calls below look synchronous; confirm they
    do not stall the event loop for long inside this async handler.
    """
    answering_suffix = "_responder_preguntas"
    try:
        # Store the event for display in the UI
        incoming_events.put({"event": type, "timestamp": str(timestamp)})

        # Get the record from the response
        record = response.record

        # Check if this is from an answering space
        dataset_name = record.dataset.name
        if not dataset_name.endswith(answering_suffix):
            print(f"Ignoring event from non-answering dataset: {dataset_name}")
            return  # Not an answering space, ignore

        # Strip only the trailing suffix; str.replace would also remove an
        # occurrence of the suffix in the middle of the dataset name.
        country = dataset_name[: -len(answering_suffix)]
        print(f"Processing response for country: {country}")

        # Connect to the validation space
        validation_dataset_name = f"{country}_validar_respuestas"
        validation_dataset = None
        try:
            validation_dataset = client.datasets(validation_dataset_name)
        except Exception as e:
            print(f"Error connecting to validation dataset: {e}")

        if validation_dataset is None:
            # The lookup can yield None for a missing dataset instead of
            # raising, so the create-on-miss fallback must cover both cases.
            print(f"Validation dataset not available, creating: {validation_dataset_name}")
            # Imported lazily: build_space is only needed on this fallback path.
            from build_space import create_validation_space
            validation_dataset = create_validation_space(country)
        else:
            print(f"Found validation dataset: {validation_dataset_name}")

        response_dict = response.to_dict()
        answer = response_dict["values"]['text']['value']

        # Get the user ID of the original responder
        original_user_id = str(response.user_id)

        # Create a validation record with the correct attribute
        validation_record = {
            "question": record.fields["question"],
            "answer": answer,
            "metadata": {
                "original_responder_id": original_user_id,
                "original_dataset": dataset_name
            }
        }

        # Add the record to the validation space
        validation_dataset.records.log(records=[validation_record])
        print(f"Added new response to validation space for {country}")

        # Update the annotation progress for the country named in the dataset
        country_code = _find_country_code(country)
        if country_code and country_code in annotation_progress:
            progress = annotation_progress[country_code]
            progress["count"] += 1
            count = progress["count"]
            # Integer arithmetic: int(290 / 1000 * 100) truncates to 28 because
            # of float rounding, whereas 290 * 100 // 1000 is exactly 29.
            percent = min(100, count * 100 // progress["target"])
            progress["percent"] = percent

            # Update event queue with progress information
            incoming_events.put({
                "event": "progress_update",
                "country": progress["name"],
                "count": count,
                "percent": percent
            })
    except Exception as e:
        # Top-level boundary of the webhook: report the failure instead of raising.
        print(f"Error in webhook handler: {e}")
        # Store the error in the queue for display
        incoming_events.put({"event": "error", "error": str(e)})
# Function to read the next event from the queue and update UI elements
def read_next_event():
    """Pop and return the oldest queued event, or ``{}`` when none is pending.

    The queue is read with a single non-blocking ``get_nowait()``. Checking
    ``empty()`` first and then calling the blocking ``get()`` is racy: another
    consumer may take the last item in between, leaving this call blocked.

    Returns:
        The next item from ``incoming_events``, or an empty dict.
    """
    # Local import so the block works without a file-level import of Empty.
    from queue import Empty

    try:
        return incoming_events.get_nowait()
    except Empty:
        return {}
# Function to get the current annotation progress data
def get_annotation_progress() -> str:
    """Serialize the module-level ``annotation_progress`` mapping to a JSON string."""
    serialized = json.dumps(annotation_progress)
    return serialized
# D3.js map visualization HTML template
def create_map_html(progress_data):
    """Build the HTML document that renders the D3 annotation-progress map.

    Args:
        progress_data: JSON text mapping a country code to an object with
            ``name``, ``count``, ``target`` and ``percent`` (the fields the
            embedded script reads). A non-string value is serialized with
            ``json.dumps``; ``None`` or a blank string becomes ``{}`` so the
            generated script stays syntactically valid.

    Returns:
        str: A complete HTML page. At view time it loads D3 from d3js.org and
        a world GeoJSON file from raw.githubusercontent.com.

    NOTE(review): features are matched on ``properties.iso_a2`` first and on
    ``properties.name`` as a fallback; the referenced GeoJSON appears to carry
    only ``name`` - confirm against the actual file.
    NOTE(review): confirm that the component displaying this HTML executes
    inline <script> tags; otherwise the page needs an iframe wrapper.
    """
    if progress_data is None:
        progress_data = "{}"
    elif not isinstance(progress_data, str):
        progress_data = json.dumps(progress_data)
    elif not progress_data.strip():
        progress_data = "{}"
    # "</" inside the data would terminate the inline <script> element early;
    # "<\/" is an equivalent escape in both JSON and JavaScript strings.
    progress_data = progress_data.replace("</", "<\\/")
    return f"""
<!DOCTYPE html>
<html>
<head>
    <script src="https://d3js.org/d3.v7.min.js"></script>
    <script src="https://d3js.org/d3-geo.v2.min.js"></script>
    <script src="https://d3js.org/topojson.v3.min.js"></script>
    <style>
        .country {{
            stroke: #f32b7b;
            stroke-width: 1px;
        }}
        .country:hover {{
            stroke: #4a1942;
            stroke-width: 2px;
            cursor: pointer;
        }}
        .tooltip {{
            position: absolute;
            background-color: rgba(0, 0, 0, 0.8);
            border-radius: 5px;
            padding: 8px;
            color: white;
            font-size: 12px;
            pointer-events: none;
            opacity: 0;
            transition: opacity 0.3s;
        }}
        text {{
            fill: #f1f5f9;
            font-family: sans-serif;
        }}
        .legend-text {{
            fill: #94a3b8;
            font-size: 10px;
        }}
    </style>
</head>
<body style="margin:0; background-color:#111;">
    <div id="map-container" style="width:100%; height:600px; position:relative;"></div>
    <div id="tooltip" class="tooltip"></div>
    <script>
        // The progress data passed from Python
        const progressData = {progress_data};
        // Secondary index by country name, used when a map feature has no ISO code
        const progressByName = {{}};
        Object.values(progressData).forEach(entry => {{
            progressByName[entry.name] = entry;
        }});
        // Resolve the progress entry for a map feature: ISO code first, then name
        const progressFor = d => {{
            const props = d.properties || {{}};
            return progressData[props.iso_a2] || progressByName[props.name];
        }};
        // Set up the SVG container
        const width = document.getElementById('map-container').clientWidth;
        const height = 600;
        const svg = d3.select("#map-container")
            .append("svg")
            .attr("width", width)
            .attr("height", height)
            .attr("viewBox", `0 0 ${{width}} ${{height}}`)
            .style("background-color", "#111");
        // Define color scale
        const colorScale = d3.scaleLinear()
            .domain([0, 100])
            .range(["#4a1942", "#f32b7b"]);
        // Set up projection focused on Latin America and Spain
        const projection = d3.geoMercator()
            .center([-60, 0])
            .scale(width / 5)
            .translate([width / 2, height / 2]);
        const path = d3.geoPath().projection(projection);
        // Tooltip setup
        const tooltip = d3.select("#tooltip");
        // Load the world GeoJSON data
        d3.json("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson")
            .then(data => {{
                // Draw the map
                svg.selectAll("path")
                    .data(data.features)
                    .enter()
                    .append("path")
                    .attr("d", path)
                    .attr("class", "country")
                    .attr("fill", d => {{
                        const entry = progressFor(d);
                        if (entry) {{
                            return colorScale(entry.percent);
                        }}
                        return "#2d3748"; // Default gray for non-tracked countries
                    }})
                    .on("mouseover", function(event, d) {{
                        const entry = progressFor(d);
                        if (entry) {{
                            tooltip.style("opacity", 1)
                                .html(`
                                    <strong>${{entry.name}}</strong><br/>
                                    Documents: ${{entry.count.toLocaleString()}}/${{entry.target.toLocaleString()}}<br/>
                                    Completion: ${{entry.percent}}%
                                `);
                        }}
                    }})
                    .on("mousemove", function(event) {{
                        tooltip.style("left", (event.pageX + 15) + "px")
                            .style("top", (event.pageY + 15) + "px");
                    }})
                    .on("mouseout", function() {{
                        tooltip.style("opacity", 0);
                    }});
                // Add legend
                const legendWidth = Math.min(width - 40, 200);
                const legendHeight = 15;
                const legendX = width - legendWidth - 20;
                const legend = svg.append("g")
                    .attr("class", "legend")
                    .attr("transform", `translate(${{legendX}}, 30)`);
                // Create gradient for legend
                const defs = svg.append("defs");
                const gradient = defs.append("linearGradient")
                    .attr("id", "dataGradient")
                    .attr("x1", "0%")
                    .attr("y1", "0%")
                    .attr("x2", "100%")
                    .attr("y2", "0%");
                gradient.append("stop")
                    .attr("offset", "0%")
                    .attr("stop-color", "#4a1942");
                gradient.append("stop")
                    .attr("offset", "100%")
                    .attr("stop-color", "#f32b7b");
                // Add legend title
                legend.append("text")
                    .attr("x", legendWidth / 2)
                    .attr("y", -10)
                    .attr("text-anchor", "middle")
                    .attr("font-size", "12px")
                    .text("Annotation Progress");
                // Add legend rectangle
                legend.append("rect")
                    .attr("width", legendWidth)
                    .attr("height", legendHeight)
                    .attr("rx", 2)
                    .attr("ry", 2)
                    .style("fill", "url(#dataGradient)");
                // Add legend labels
                legend.append("text")
                    .attr("class", "legend-text")
                    .attr("x", 0)
                    .attr("y", legendHeight + 15)
                    .attr("text-anchor", "start")
                    .text("0%");
                legend.append("text")
                    .attr("class", "legend-text")
                    .attr("x", legendWidth / 2)
                    .attr("y", legendHeight + 15)
                    .attr("text-anchor", "middle")
                    .text("50%");
                legend.append("text")
                    .attr("class", "legend-text")
                    .attr("x", legendWidth)
                    .attr("y", legendHeight + 15)
                    .attr("text-anchor", "end")
                    .text("100%");
            }});
        // Handle window resize
        window.addEventListener('resize', () => {{
            const width = document.getElementById('map-container').clientWidth;
            // Update SVG dimensions
            d3.select("svg")
                .attr("width", width)
                .attr("viewBox", `0 0 ${{width}} ${{height}}`);
            // Update projection
            projection.scale(width / 5)
                .translate([width / 2, height / 2]);
            // Update paths
            d3.selectAll("path").attr("d", path);
            // Update legend position
            const legendWidth = Math.min(width - 40, 200);
            const legendX = width - legendWidth - 20;
            d3.select(".legend")
                .attr("transform", `translate(${{legendX}}, 30)`);
        }});
    </script>
</body>
</html>
"""
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple")) as demo:
    # Shown in the header; falls back to a placeholder when the client exposes no http_client.
    argilla_server = client.http_client.base_url if hasattr(client, 'http_client') else "Not connected"

    with gr.Row():
        # The Markdown text is kept flush-left so its content is unchanged.
        gr.Markdown(f"""
# Argilla Annotation Progress Map
### Connected to Argilla server: {argilla_server}
This dashboard visualizes annotation progress across Latin America and Spain.
The webhook listens for `response.created` events from datasets ending with `_responder_preguntas`.
""")

    with gr.Row():
        with gr.Column(scale=2):
            # Map visualization
            map_html = gr.HTML(create_map_html(json.dumps(annotation_progress)), label="Annotation Progress Map")
            # Stats section
            with gr.Accordion("Overall Statistics", open=False):
                total_docs = gr.Number(value=0, label="Total Documents Collected")
                avg_completion = gr.Number(value=0, label="Average Completion (%)")
                countries_over_50 = gr.Number(value=0, label="Countries Over 50% Complete")
        with gr.Column(scale=1):
            # Recent events log
            events_json = gr.JSON(label="Recent Events", value={})
            # Country details
            with gr.Accordion("Country Details", open=True):
                country_selector = gr.Dropdown(
                    choices=[f"{data['name']} ({code})" for code, data in COUNTRY_MAPPING.items()],
                    label="Select Country"
                )
                country_progress = gr.JSON(label="Country Progress", value={})

    # Functions to update the UI
    def update_map():
        """Re-render the map HTML from the current annotation progress."""
        progress_json = json.dumps(annotation_progress)
        return create_map_html(progress_json)

    def update_stats():
        """Return (total count, mean percent, number of countries at or above 50%)."""
        total = sum(data["count"] for data in annotation_progress.values())
        percentages = [data["percent"] for data in annotation_progress.values()]
        avg = sum(percentages) / len(percentages) if percentages else 0
        countries_50_plus = sum(1 for p in percentages if p >= 50)
        return total, avg, countries_50_plus

    def update_country_details(country_selection):
        """Return the progress entry for a dropdown choice, or {} when nothing matches."""
        if not country_selection:
            return {}
        # Extract the country code from the selection (format: "Country Name (CODE)")
        code = country_selection.split("(")[-1].replace(")", "").strip()
        if code in annotation_progress:
            return annotation_progress[code]
        return {}

    # Simple event processing functions for better compatibility
    def update_events():
        """Return the next queued event, or a no-op update when none is pending.

        Returning {} on every idle tick would blank the "Recent Events" panel
        one second after an event is shown; an empty gr.update() leaves the
        last displayed event in place.
        """
        event = read_next_event()
        return event if event else gr.update()

    def update_all_stats():
        """Timer callback that forwards to update_stats()."""
        return update_stats()

    # Use separate timers for each component for better compatibility
    gr.Timer(1, active=True).tick(update_events, outputs=events_json)
    gr.Timer(5, active=True).tick(update_all_stats, outputs=[total_docs, avg_completion, countries_over_50])

    country_selector.change(
        fn=update_country_details,
        inputs=[country_selector],
        outputs=[country_progress]
    )

    # Use refresh button instead of timer for map updates (more compatible across versions)
    refresh_btn = gr.Button("Refresh Map")
    refresh_btn.click(
        fn=update_map,
        inputs=None,
        outputs=[map_html]
    )

# Mount the Gradio app to the FastAPI server
gr.mount_gradio_app(server, demo, path="/")
# Start the FastAPI server
if __name__ == "__main__":
    import uvicorn

    # Initialize with some sample data.
    # country code -> (fraction of the target used as count, percent to display)
    sample_progress = {
        "MX": (0.3, 30),
        "AR": (0.3, 30),
        "CO": (0.3, 30),
        "ES": (0.3, 30),
        "BR": (0.5, 50),
        "CL": (0.7, 70),
    }
    for iso_code, (fraction, shown_percent) in sample_progress.items():
        entry = annotation_progress[iso_code]
        entry["count"] = int(entry["target"] * fraction)
        entry["percent"] = shown_percent

    # Start the server
    uvicorn.run(server, host="0.0.0.0", port=7860)