ouhenio commited on
Commit
766b697
·
verified ·
1 Parent(s): f8842a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +563 -417
app.py CHANGED
@@ -4,464 +4,610 @@ import json
4
  import gradio as gr
5
  import argilla as rg
6
  from argilla.webhooks import webhook_listener
 
 
 
7
 
8
- # Initialize Argilla client
9
- client = rg.Argilla(
10
- api_url=os.getenv("ARGILLA_API_URL"),
11
- api_key=os.getenv("ARGILLA_API_KEY"),
12
- )
13
 
14
- # Get the webhook server
15
- server = rg.get_webhook_server()
 
16
 
17
- # Queue to store events for display
18
- incoming_events = Queue()
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Dictionary to store annotation progress by country
21
- annotation_progress = {
22
- # Format will be:
23
- # "country_code": {"count": 0, "percent": 0, "name": "Country Name"}
24
- }
 
 
 
 
25
 
26
- # Country mapping (ISO code to name)
27
- COUNTRY_MAPPING = {
28
- "MX": {"name": "Mexico", "target": 1000},
29
- "AR": {"name": "Argentina", "target": 800},
30
- "CO": {"name": "Colombia", "target": 700},
31
- "CL": {"name": "Chile", "target": 600},
32
- "PE": {"name": "Peru", "target": 600},
33
- "ES": {"name": "Spain", "target": 1200},
34
- "BR": {"name": "Brazil", "target": 1000},
35
- "VE": {"name": "Venezuela", "target": 500},
36
- "EC": {"name": "Ecuador", "target": 400},
37
- "BO": {"name": "Bolivia", "target": 300},
38
- "PY": {"name": "Paraguay", "target": 300},
39
- "UY": {"name": "Uruguay", "target": 300},
40
- "CR": {"name": "Costa Rica", "target": 250},
41
- "PA": {"name": "Panama", "target": 250},
42
- "DO": {"name": "Dominican Republic", "target": 300},
43
- "GT": {"name": "Guatemala", "target": 250},
44
- "HN": {"name": "Honduras", "target": 200},
45
- "SV": {"name": "El Salvador", "target": 200},
46
- "NI": {"name": "Nicaragua", "target": 200},
47
- "CU": {"name": "Cuba", "target": 300}
48
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Initialize the annotation progress data
51
- for country_code, data in COUNTRY_MAPPING.items():
52
- annotation_progress[country_code] = {
53
- "count": 0,
54
- "percent": 0,
55
- "name": data["name"],
56
- "target": data["target"]
57
- }
58
 
59
- # Set up the webhook listener for response creation
60
- @webhook_listener(events=["response.created"])
61
- async def update_annotation_progress(response, type, timestamp):
62
- """
63
- Webhook listener that triggers when a new response is added to an Argilla dataset.
64
- It will update the annotation progress for the corresponding country.
65
- """
66
- try:
67
- # Store the event for display in the UI
68
- incoming_events.put({"event": type, "timestamp": str(timestamp)})
69
-
70
- # Get the record from the response
71
- record = response.record
72
-
73
- # Get dataset name
74
- dataset_name = record.dataset.name
75
- print(f"Processing response for dataset: {dataset_name}")
76
-
77
- # Try to determine the country from the dataset name
78
- country_code = None
79
- for code, data in COUNTRY_MAPPING.items():
80
- country_name = data["name"].lower()
81
- if country_name in dataset_name.lower():
82
- country_code = code
83
- break
84
-
85
- # If we found a matching country, update its progress
86
- if country_code and country_code in annotation_progress:
87
- # Increment the count
88
- annotation_progress[country_code]["count"] += 1
89
-
90
- # Update the percentage
91
- target = annotation_progress[country_code]["target"]
92
- count = annotation_progress[country_code]["count"]
93
- percent = min(100, int((count / target) * 100))
94
- annotation_progress[country_code]["percent"] = percent
95
-
96
- # Update event queue with progress information
97
- incoming_events.put({
98
- "event": "progress_update",
99
- "country": annotation_progress[country_code]["name"],
100
- "count": count,
101
- "percent": percent
102
- })
103
- print(f"Updated progress for {annotation_progress[country_code]['name']}: {percent}%")
104
-
105
- except Exception as e:
106
- print(f"Error in webhook handler: {e}")
107
- # Store the error in the queue for display
108
- incoming_events.put({"event": "error", "error": str(e)})
109
 
110
- # Function to read the next event from the queue
111
- def read_next_event():
112
- if not incoming_events.empty():
113
- event = incoming_events.get()
114
- return event
115
- return {}
116
 
117
- # Function to calculate overall statistics
118
- def update_stats():
119
- total = sum(data["count"] for data in annotation_progress.values())
120
- percentages = [data["percent"] for data in annotation_progress.values()]
121
- avg = sum(percentages) / len(percentages) if percentages else 0
122
- countries_50_plus = sum(1 for p in percentages if p >= 50)
 
 
 
 
 
 
123
 
124
- return total, avg, countries_50_plus
 
 
 
 
 
 
125
 
126
- # Create the map HTML container (without D3.js code)
127
- def create_map_html():
128
- return """
129
- <div id="map-container" style="width:100%; height:600px; position:relative; background-color:#111;">
130
- <div style="display:flex; justify-content:center; align-items:center; height:100%; color:white; font-family:sans-serif;">
131
- Loading map visualization...
132
- </div>
133
- </div>
134
- <div id="tooltip" style="position:absolute; background-color:rgba(0,0,0,0.8); border-radius:5px; padding:8px; color:white; font-size:12px; pointer-events:none; opacity:0; transition:opacity 0.3s;"></div>
135
- """
 
 
 
 
136
 
137
- # Create D3.js script that will be loaded via Gradio's JavaScript execution
138
- def create_d3_script(progress_data):
139
- return f"""
140
- async () => {{
141
- // Load D3.js modules
142
- const script1 = document.createElement("script");
143
- script1.src = "https://cdn.jsdelivr.net/npm/d3@7";
144
- document.head.appendChild(script1);
145
-
146
- // Wait for D3 to load
147
- await new Promise(resolve => {{
148
- script1.onload = resolve;
149
- }});
150
-
151
- console.log("D3 loaded successfully");
152
-
153
- // Load topojson
154
- const script2 = document.createElement("script");
155
- script2.src = "https://cdn.jsdelivr.net/npm/topojson@3";
156
- document.head.appendChild(script2);
157
-
158
- await new Promise(resolve => {{
159
- script2.onload = resolve;
160
- }});
161
-
162
- console.log("TopoJSON loaded successfully");
163
-
164
- // The progress data passed from Python
165
- const progressData = {progress_data};
166
-
167
- // Set up the SVG container
168
- const mapContainer = document.getElementById('map-container');
169
- mapContainer.innerHTML = ''; // Clear loading message
170
-
171
- const width = mapContainer.clientWidth;
172
- const height = 600;
173
-
174
- const svg = d3.select("#map-container")
175
- .append("svg")
176
- .attr("width", width)
177
- .attr("height", height)
178
- .attr("viewBox", `0 0 ${{width}} ${{height}}`)
179
- .style("background-color", "#111");
180
-
181
- // Define color scale
182
- const colorScale = d3.scaleLinear()
183
- .domain([0, 100])
184
- .range(["#4a1942", "#f32b7b"]);
185
-
186
- // Set up projection focused on Latin America and Spain
187
- const projection = d3.geoMercator()
188
- .center([-60, 0])
189
- .scale(width / 5)
190
- .translate([width / 2, height / 2]);
191
 
192
- const path = d3.geoPath().projection(projection);
193
-
194
- // Tooltip setup
195
- const tooltip = d3.select("#tooltip");
196
-
197
- // Load the world GeoJSON data
198
- const response = await fetch("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson");
199
- const data = await response.json();
200
-
201
- // Draw the map
202
- svg.selectAll("path")
203
- .data(data.features)
204
- .enter()
205
- .append("path")
206
- .attr("d", path)
207
- .attr("stroke", "#f32b7b")
208
- .attr("stroke-width", 1)
209
- .attr("fill", d => {{
210
- // Get the ISO code from the properties
211
- const iso = d.properties.iso_a2;
212
-
213
- if (progressData[iso]) {{
214
- return colorScale(progressData[iso].percent);
215
- }}
216
- return "#2d3748"; // Default gray for non-tracked countries
217
- }})
218
- .on("mouseover", function(event, d) {{
219
- const iso = d.properties.iso_a2;
220
-
221
- d3.select(this)
222
- .attr("stroke", "#4a1942")
223
- .attr("stroke-width", 2);
224
-
225
- if (progressData[iso]) {{
226
- tooltip.style("opacity", 1)
227
- .style("left", (event.pageX + 15) + "px")
228
- .style("top", (event.pageY + 15) + "px")
229
- .html(`
230
- <strong>${{progressData[iso].name}}</strong><br/>
231
- Documents: ${{progressData[iso].count.toLocaleString()}}/${{progressData[iso].target.toLocaleString()}}<br/>
232
- Completion: ${{progressData[iso].percent}}%
233
- `);
234
- }}
235
- }})
236
- .on("mousemove", function(event) {{
237
- tooltip.style("left", (event.pageX + 15) + "px")
238
- .style("top", (event.pageY + 15) + "px");
239
- }})
240
- .on("mouseout", function() {{
241
- d3.select(this)
242
- .attr("stroke", "#f32b7b")
243
- .attr("stroke-width", 1);
244
-
245
- tooltip.style("opacity", 0);
246
  }});
247
 
248
- // Add legend
249
- const legendWidth = Math.min(width - 40, 200);
250
- const legendHeight = 15;
251
- const legendX = width - legendWidth - 20;
252
-
253
- const legend = svg.append("g")
254
- .attr("class", "legend")
255
- .attr("transform", `translate(${{legendX}}, 30)`);
256
-
257
- // Create gradient for legend
258
- const defs = svg.append("defs");
259
- const gradient = defs.append("linearGradient")
260
- .attr("id", "dataGradient")
261
- .attr("x1", "0%")
262
- .attr("y1", "0%")
263
- .attr("x2", "100%")
264
- .attr("y2", "0%");
265
-
266
- gradient.append("stop")
267
- .attr("offset", "0%")
268
- .attr("stop-color", "#4a1942");
269
-
270
- gradient.append("stop")
271
- .attr("offset", "100%")
272
- .attr("stop-color", "#f32b7b");
273
-
274
- // Add legend title
275
- legend.append("text")
276
- .attr("x", legendWidth / 2)
277
- .attr("y", -10)
278
- .attr("text-anchor", "middle")
279
- .attr("font-size", "12px")
280
- .attr("fill", "#f1f5f9")
281
- .text("Annotation Progress");
282
-
283
- // Add legend rectangle
284
- legend.append("rect")
285
- .attr("width", legendWidth)
286
- .attr("height", legendHeight)
287
- .attr("rx", 2)
288
- .attr("ry", 2)
289
- .style("fill", "url(#dataGradient)");
290
-
291
- // Add legend labels
292
- legend.append("text")
293
- .attr("x", 0)
294
- .attr("y", legendHeight + 15)
295
- .attr("text-anchor", "start")
296
- .attr("font-size", "10px")
297
- .attr("fill", "#94a3b8")
298
- .text("0%");
299
-
300
- legend.append("text")
301
- .attr("x", legendWidth / 2)
302
- .attr("y", legendHeight + 15)
303
- .attr("text-anchor", "middle")
304
- .attr("font-size", "10px")
305
- .attr("fill", "#94a3b8")
306
- .text("50%");
307
-
308
- legend.append("text")
309
- .attr("x", legendWidth)
310
- .attr("y", legendHeight + 15)
311
- .attr("text-anchor", "end")
312
- .attr("font-size", "10px")
313
- .attr("fill", "#94a3b8")
314
- .text("100%");
315
-
316
- // Handle window resize
317
- globalThis.resizeMap = () => {{
318
  const width = mapContainer.clientWidth;
 
319
 
320
- // Update SVG dimensions
321
- d3.select("svg")
322
  .attr("width", width)
323
- .attr("viewBox", `0 0 ${{width}} ${{height}}`);
 
 
 
 
 
 
 
324
 
325
- // Update projection
326
- projection.scale(width / 5)
 
 
327
  .translate([width / 2, height / 2]);
 
 
 
 
 
328
 
329
- // Update paths
330
- d3.selectAll("path").attr("d", path);
 
331
 
332
- // Update legend position
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  const legendWidth = Math.min(width - 40, 200);
 
334
  const legendX = width - legendWidth - 20;
335
 
336
- d3.select(".legend")
 
337
  .attr("transform", `translate(${{legendX}}, 30)`);
338
- }};
339
-
340
- window.addEventListener('resize', globalThis.resizeMap);
341
- }}
342
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
- # Function to update the map data and trigger a reload
345
- def update_map():
346
- progress_json = json.dumps(annotation_progress)
347
- return progress_json
348
 
349
- # Create Gradio interface
350
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple")) as demo:
351
- argilla_server = client.http_client.base_url if hasattr(client, 'http_client') else "Not connected"
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
- with gr.Row():
354
- gr.Markdown(f"""
355
- # Latin America & Spain Annotation Progress Map
356
-
357
- ### Connected to Argilla server: {argilla_server}
358
-
359
- This dashboard visualizes annotation progress across Latin America and Spain.
360
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
- with gr.Row():
363
- with gr.Column(scale=2):
364
- # Map visualization - empty at first
365
- map_html = gr.HTML(create_map_html(), label="Annotation Progress Map")
 
 
 
 
366
 
367
- # Hidden element to store map data
368
- map_data = gr.JSON(value=json.dumps(annotation_progress), visible=False)
369
-
370
- with gr.Column(scale=1):
371
- # Recent events log
372
- events_json = gr.JSON(label="Recent Events", value={})
373
-
374
- # Overall statistics
375
- total_docs = gr.Number(value=0, label="Total Documents", interactive=False)
376
- avg_completion = gr.Number(value=0, label="Average Completion (%)", interactive=False)
377
- countries_over_50 = gr.Number(value=0, label="Countries Over 50%", interactive=False)
378
-
379
- # Country details
380
- country_selector = gr.Dropdown(
381
- choices=[f"{data['name']} ({code})" for code, data in COUNTRY_MAPPING.items()],
382
- label="Select Country"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  )
384
- country_progress = gr.JSON(label="Country Progress", value={})
385
-
386
- # Load the D3 script when data is updated
387
- def load_map_script(data):
388
- return None, create_d3_script(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
 
390
- # Refresh button
391
- refresh_btn = gr.Button("Refresh Map")
392
- refresh_btn.click(
393
- fn=update_map,
394
- inputs=None,
395
- outputs=map_data
396
- )
397
-
398
- # When map_data is updated, reload the D3 script
399
- map_data.change(
400
- fn=load_map_script,
401
- inputs=map_data,
402
- outputs=[None, None],
403
- _js=create_d3_script(json.dumps(annotation_progress))
404
- )
405
 
406
- # Function to update country details
407
- def update_country_details(country_selection):
408
- if not country_selection:
409
- return {}
410
-
411
- # Extract the country code from the selection (format: "Country Name (CODE)")
412
- code = country_selection.split("(")[-1].replace(")", "").strip()
413
-
414
- if code in annotation_progress:
415
- return annotation_progress[code]
416
- return {}
417
 
418
- # Update country details when a country is selected
419
- country_selector.change(
420
- fn=update_country_details,
421
- inputs=[country_selector],
422
- outputs=[country_progress]
423
- )
424
 
425
- # Update events function for the timer
426
- def update_events():
427
- event = read_next_event()
428
-
429
- # Calculate stats
430
- stats = update_stats()
431
-
432
- # If this is a progress update, update the map data
433
- if event.get("event") == "progress_update":
434
- # This will indirectly trigger a map refresh through the change event
435
- return event, json.dumps(annotation_progress), stats[0], stats[1], stats[2]
436
-
437
- return event, None, stats[0], stats[1], stats[2]
438
 
439
- # Make final updates to Gradio demo
440
- demo.load(None, None, None, _js=create_d3_script(json.dumps(annotation_progress)))
441
 
442
- # Use timer to check for new events and update stats
443
- gr.Timer(1, active=True).tick(
444
- update_events,
445
- outputs=[events_json, map_data, total_docs, avg_completion, countries_over_50]
446
- )
447
-
448
- # Mount the Gradio app to the FastAPI server
449
- gr.mount_gradio_app(server, demo, path="/")
450
-
451
- # Start the FastAPI server
452
  if __name__ == "__main__":
453
  import uvicorn
454
 
455
- # Initialize with some sample data
456
- for code in ["MX", "AR", "CO", "ES"]:
457
- annotation_progress[code]["count"] = int(annotation_progress[code]["target"] * 0.3)
458
- annotation_progress[code]["percent"] = 30
459
-
460
- annotation_progress["BR"]["count"] = int(annotation_progress["BR"]["target"] * 0.5)
461
- annotation_progress["BR"]["percent"] = 50
462
-
463
- annotation_progress["CL"]["count"] = int(annotation_progress["CL"]["target"] * 0.7)
464
- annotation_progress["CL"]["percent"] = 70
465
 
466
  # Start the server
467
  uvicorn.run(server, host="0.0.0.0", port=7860)
 
4
  import gradio as gr
5
  import argilla as rg
6
  from argilla.webhooks import webhook_listener
7
+ from dataclasses import dataclass, field, asdict
8
+ from typing import Dict, List, Optional, Tuple, Any, Callable
9
+ import logging
10
 
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
13
+ logger = logging.getLogger(__name__)
 
 
14
 
15
+ # ============================================================================
16
+ # DATA MODELS - Clear definition of data structures
17
+ # ============================================================================
18
 
19
+ @dataclass
20
+ class CountryData:
21
+ """Data model for country information and annotation progress."""
22
+ name: str
23
+ target: int
24
+ count: int = 0
25
+ percent: int = 0
26
+
27
+ def update_progress(self, new_count: Optional[int] = None):
28
+ """Update the progress percentage based on count/target."""
29
+ if new_count is not None:
30
+ self.count = new_count
31
+ self.percent = min(100, int((self.count / self.target) * 100))
32
+ return self
33
 
34
+ @dataclass
35
+ class Event:
36
+ """Data model for events in the system."""
37
+ event_type: str
38
+ timestamp: str = ""
39
+ country: str = ""
40
+ count: int = 0
41
+ percent: int = 0
42
+ error: str = ""
43
 
44
+ @dataclass
45
+ class ApplicationState:
46
+ """Central state management for the application."""
47
+ countries: Dict[str, CountryData] = field(default_factory=dict)
48
+ events: Queue = field(default_factory=Queue)
49
+
50
+ def to_dict(self) -> Dict[str, Any]:
51
+ """Convert state to a serializable dictionary for the UI."""
52
+ return {
53
+ code: asdict(data) for code, data in self.countries.items()
54
+ }
55
+
56
+ def to_json(self) -> str:
57
+ """Convert state to JSON for the UI."""
58
+ return json.dumps(self.to_dict())
59
+
60
+ def add_event(self, event: Event):
61
+ """Add an event to the queue."""
62
+ self.events.put(asdict(event))
63
+
64
+ def get_next_event(self) -> Dict[str, Any]:
65
+ """Get the next event from the queue."""
66
+ if not self.events.empty():
67
+ return self.events.get()
68
+ return {}
69
+
70
+ def update_country_progress(self, country_code: str, count: Optional[int] = None) -> bool:
71
+ """Update a country's annotation progress."""
72
+ if country_code in self.countries:
73
+ if count is not None:
74
+ self.countries[country_code].count = count
75
+ self.countries[country_code].update_progress()
76
+
77
+ # Create and add a progress update event
78
+ self.add_event(Event(
79
+ event_type="progress_update",
80
+ country=self.countries[country_code].name,
81
+ count=self.countries[country_code].count,
82
+ percent=self.countries[country_code].percent
83
+ ))
84
+ return True
85
+ return False
86
+
87
+ def increment_country_progress(self, country_code: str) -> bool:
88
+ """Increment a country's annotation count by 1."""
89
+ if country_code in self.countries:
90
+ self.countries[country_code].count += 1
91
+ return self.update_country_progress(country_code)
92
+ return False
93
+
94
+ def get_stats(self) -> Tuple[int, float, int]:
95
+ """Calculate overall statistics."""
96
+ total = sum(data.count for data in self.countries.values())
97
+ percentages = [data.percent for data in self.countries.values()]
98
+ avg = sum(percentages) / len(percentages) if percentages else 0
99
+ countries_50_plus = sum(1 for p in percentages if p >= 50)
100
+
101
+ return total, avg, countries_50_plus
102
 
103
+ # ============================================================================
104
+ # CONFIGURATION - Separated from business logic
105
+ # ============================================================================
 
 
 
 
 
106
 
107
+ class Config:
108
+ """Configuration for the application."""
109
+ # Country mapping (ISO code to name and target)
110
+ COUNTRY_MAPPING = {
111
+ "MX": {"name": "Mexico", "target": 1000},
112
+ "AR": {"name": "Argentina", "target": 800},
113
+ "CO": {"name": "Colombia", "target": 700},
114
+ "CL": {"name": "Chile", "target": 600},
115
+ "PE": {"name": "Peru", "target": 600},
116
+ "ES": {"name": "Spain", "target": 1200},
117
+ "BR": {"name": "Brazil", "target": 1000},
118
+ "VE": {"name": "Venezuela", "target": 500},
119
+ "EC": {"name": "Ecuador", "target": 400},
120
+ "BO": {"name": "Bolivia", "target": 300},
121
+ "PY": {"name": "Paraguay", "target": 300},
122
+ "UY": {"name": "Uruguay", "target": 300},
123
+ "CR": {"name": "Costa Rica", "target": 250},
124
+ "PA": {"name": "Panama", "target": 250},
125
+ "DO": {"name": "Dominican Republic", "target": 300},
126
+ "GT": {"name": "Guatemala", "target": 250},
127
+ "HN": {"name": "Honduras", "target": 200},
128
+ "SV": {"name": "El Salvador", "target": 200},
129
+ "NI": {"name": "Nicaragua", "target": 200},
130
+ "CU": {"name": "Cuba", "target": 300}
131
+ }
132
+
133
+ @classmethod
134
+ def create_country_data(cls) -> Dict[str, CountryData]:
135
+ """Create CountryData objects from the mapping."""
136
+ return {
137
+ code: CountryData(
138
+ name=data["name"],
139
+ target=data["target"]
140
+ ) for code, data in cls.COUNTRY_MAPPING.items()
141
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # ============================================================================
144
+ # SERVICES - Business logic separated from presentation and data access
145
+ # ============================================================================
 
 
 
146
 
147
+ class ArgillaService:
148
+ """Service for interacting with Argilla."""
149
+ def __init__(self, api_url: Optional[str] = None, api_key: Optional[str] = None):
150
+ """Initialize the Argilla service."""
151
+ self.api_url = api_url or os.getenv("ARGILLA_API_URL")
152
+ self.api_key = api_key or os.getenv("ARGILLA_API_KEY")
153
+
154
+ self.client = rg.Argilla(
155
+ api_url=self.api_url,
156
+ api_key=self.api_key,
157
+ )
158
+ self.server = rg.get_webhook_server()
159
 
160
+ def get_server(self):
161
+ """Get the Argilla webhook server."""
162
+ return self.server
163
+
164
+ def get_client_base_url(self) -> str:
165
+ """Get the base URL of the Argilla client."""
166
+ return self.client.http_client.base_url if hasattr(self.client, 'http_client') else "Not connected"
167
 
168
+ class CountryMappingService:
169
+ """Service for mapping between dataset names and country codes."""
170
+ @staticmethod
171
+ def find_country_code_from_dataset(dataset_name: str) -> Optional[str]:
172
+ """
173
+ Try to extract a country code from a dataset name by matching
174
+ country names in the dataset name.
175
+ """
176
+ dataset_name_lower = dataset_name.lower()
177
+ for code, data in Config.COUNTRY_MAPPING.items():
178
+ country_name = data["name"].lower()
179
+ if country_name in dataset_name_lower:
180
+ return code
181
+ return None
182
 
183
+ # ============================================================================
184
+ # UI COMPONENTS - Presentation layer separated from business logic
185
+ # ============================================================================
186
+
187
+ class MapVisualization:
188
+ """Component for D3.js map visualization."""
189
+ @staticmethod
190
+ def create_map_html() -> str:
191
+ """Create the initial HTML container for the map."""
192
+ return """
193
+ <div id="map-container" style="width:100%; height:600px; position:relative; background-color:#111;">
194
+ <div style="display:flex; justify-content:center; align-items:center; height:100%; color:white; font-family:sans-serif;">
195
+ Loading map visualization...
196
+ </div>
197
+ </div>
198
+ <div id="tooltip" style="position:absolute; background-color:rgba(0,0,0,0.8); border-radius:5px; padding:8px; color:white; font-size:12px; pointer-events:none; opacity:0; transition:opacity 0.3s;"></div>
199
+ """
200
+
201
+ @staticmethod
202
+ def create_d3_script(progress_data: str) -> str:
203
+ """Create the D3.js script for rendering the map."""
204
+ return f"""
205
+ async () => {{
206
+ // Load D3.js modules
207
+ const script1 = document.createElement("script");
208
+ script1.src = "https://cdn.jsdelivr.net/npm/d3@7";
209
+ document.head.appendChild(script1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ // Wait for D3 to load
212
+ await new Promise(resolve => {{
213
+ script1.onload = resolve;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  }});
215
 
216
+ console.log("D3 loaded successfully");
217
+
218
+ // Load topojson
219
+ const script2 = document.createElement("script");
220
+ script2.src = "https://cdn.jsdelivr.net/npm/topojson@3";
221
+ document.head.appendChild(script2);
222
+
223
+ await new Promise(resolve => {{
224
+ script2.onload = resolve;
225
+ }});
226
+
227
+ console.log("TopoJSON loaded successfully");
228
+
229
+ // The progress data passed from Python
230
+ const progressData = {progress_data};
231
+
232
+ // Set up the SVG container
233
+ const mapContainer = document.getElementById('map-container');
234
+ mapContainer.innerHTML = ''; // Clear loading message
235
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  const width = mapContainer.clientWidth;
237
+ const height = 600;
238
 
239
+ const svg = d3.select("#map-container")
240
+ .append("svg")
241
  .attr("width", width)
242
+ .attr("height", height)
243
+ .attr("viewBox", `0 0 ${{width}} ${{height}}`)
244
+ .style("background-color", "#111");
245
+
246
+ // Define color scale
247
+ const colorScale = d3.scaleLinear()
248
+ .domain([0, 100])
249
+ .range(["#4a1942", "#f32b7b"]);
250
 
251
+ // Set up projection focused on Latin America and Spain
252
+ const projection = d3.geoMercator()
253
+ .center([-60, 0])
254
+ .scale(width / 5)
255
  .translate([width / 2, height / 2]);
256
+
257
+ const path = d3.geoPath().projection(projection);
258
+
259
+ // Tooltip setup
260
+ const tooltip = d3.select("#tooltip");
261
 
262
+ // Load the world GeoJSON data
263
+ const response = await fetch("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson");
264
+ const data = await response.json();
265
 
266
+ // Draw the map
267
+ svg.selectAll("path")
268
+ .data(data.features)
269
+ .enter()
270
+ .append("path")
271
+ .attr("d", path)
272
+ .attr("stroke", "#f32b7b")
273
+ .attr("stroke-width", 1)
274
+ .attr("fill", d => {{
275
+ // Get the ISO code from the properties
276
+ const iso = d.properties.iso_a2;
277
+
278
+ if (progressData[iso]) {{
279
+ return colorScale(progressData[iso].percent);
280
+ }}
281
+ return "#2d3748"; // Default gray for non-tracked countries
282
+ }})
283
+ .on("mouseover", function(event, d) {{
284
+ const iso = d.properties.iso_a2;
285
+
286
+ d3.select(this)
287
+ .attr("stroke", "#4a1942")
288
+ .attr("stroke-width", 2);
289
+
290
+ if (progressData[iso]) {{
291
+ tooltip.style("opacity", 1)
292
+ .style("left", (event.pageX + 15) + "px")
293
+ .style("top", (event.pageY + 15) + "px")
294
+ .html(`
295
+ <strong>${{progressData[iso].name}}</strong><br/>
296
+ Documents: ${{progressData[iso].count.toLocaleString()}}/${{progressData[iso].target.toLocaleString()}}<br/>
297
+ Completion: ${{progressData[iso].percent}}%
298
+ `);
299
+ }}
300
+ }})
301
+ .on("mousemove", function(event) {{
302
+ tooltip.style("left", (event.pageX + 15) + "px")
303
+ .style("top", (event.pageY + 15) + "px");
304
+ }})
305
+ .on("mouseout", function() {{
306
+ d3.select(this)
307
+ .attr("stroke", "#f32b7b")
308
+ .attr("stroke-width", 1);
309
+
310
+ tooltip.style("opacity", 0);
311
+ }});
312
+
313
+ // Add legend
314
  const legendWidth = Math.min(width - 40, 200);
315
+ const legendHeight = 15;
316
  const legendX = width - legendWidth - 20;
317
 
318
+ const legend = svg.append("g")
319
+ .attr("class", "legend")
320
  .attr("transform", `translate(${{legendX}}, 30)`);
321
+
322
+ // Create gradient for legend
323
+ const defs = svg.append("defs");
324
+ const gradient = defs.append("linearGradient")
325
+ .attr("id", "dataGradient")
326
+ .attr("x1", "0%")
327
+ .attr("y1", "0%")
328
+ .attr("x2", "100%")
329
+ .attr("y2", "0%");
330
+
331
+ gradient.append("stop")
332
+ .attr("offset", "0%")
333
+ .attr("stop-color", "#4a1942");
334
+
335
+ gradient.append("stop")
336
+ .attr("offset", "100%")
337
+ .attr("stop-color", "#f32b7b");
338
+
339
+ // Add legend title
340
+ legend.append("text")
341
+ .attr("x", legendWidth / 2)
342
+ .attr("y", -10)
343
+ .attr("text-anchor", "middle")
344
+ .attr("font-size", "12px")
345
+ .attr("fill", "#f1f5f9")
346
+ .text("Annotation Progress");
347
+
348
+ // Add legend rectangle
349
+ legend.append("rect")
350
+ .attr("width", legendWidth)
351
+ .attr("height", legendHeight)
352
+ .attr("rx", 2)
353
+ .attr("ry", 2)
354
+ .style("fill", "url(#dataGradient)");
355
+
356
+ // Add legend labels
357
+ legend.append("text")
358
+ .attr("x", 0)
359
+ .attr("y", legendHeight + 15)
360
+ .attr("text-anchor", "start")
361
+ .attr("font-size", "10px")
362
+ .attr("fill", "#94a3b8")
363
+ .text("0%");
364
+
365
+ legend.append("text")
366
+ .attr("x", legendWidth / 2)
367
+ .attr("y", legendHeight + 15)
368
+ .attr("text-anchor", "middle")
369
+ .attr("font-size", "10px")
370
+ .attr("fill", "#94a3b8")
371
+ .text("50%");
372
+
373
+ legend.append("text")
374
+ .attr("x", legendWidth)
375
+ .attr("y", legendHeight + 15)
376
+ .attr("text-anchor", "end")
377
+ .attr("font-size", "10px")
378
+ .attr("fill", "#94a3b8")
379
+ .text("100%");
380
+
381
+ // Handle window resize
382
+ globalThis.resizeMap = () => {{
383
+ const width = mapContainer.clientWidth;
384
+
385
+ // Update SVG dimensions
386
+ d3.select("svg")
387
+ .attr("width", width)
388
+ .attr("viewBox", `0 0 ${{width}} ${{height}}`);
389
+
390
+ // Update projection
391
+ projection.scale(width / 5)
392
+ .translate([width / 2, height / 2]);
393
+
394
+ // Update paths
395
+ d3.selectAll("path").attr("d", path);
396
+
397
+ // Update legend position
398
+ const legendWidth = Math.min(width - 40, 200);
399
+ const legendX = width - legendWidth - 20;
400
+
401
+ d3.select(".legend")
402
+ .attr("transform", `translate(${{legendX}}, 30)`);
403
+ }};
404
+
405
+ window.addEventListener('resize', globalThis.resizeMap);
406
+ }}
407
+ """
408
 
409
+ # ============================================================================
410
+ # APPLICATION FACTORY - Creates and configures the application
411
+ # ============================================================================
 
412
 
413
+ class ApplicationFactory:
414
+ """Factory for creating the application components."""
415
+ @classmethod
416
+ def create_app_state(cls) -> ApplicationState:
417
+ """Create and initialize the application state."""
418
+ state = ApplicationState(countries=Config.create_country_data())
419
+
420
+ # Initialize with some sample data
421
+ for code in ["MX", "AR", "CO", "ES"]:
422
+ sample_count = int(state.countries[code].target * 0.3)
423
+ state.update_country_progress(code, sample_count)
424
+
425
+ state.update_country_progress("BR", int(state.countries["BR"].target * 0.5))
426
+ state.update_country_progress("CL", int(state.countries["CL"].target * 0.7))
427
+
428
+ return state
429
 
430
+ @classmethod
431
+ def create_argilla_service(cls) -> ArgillaService:
432
+ """Create the Argilla service."""
433
+ return ArgillaService()
434
+
435
+ @classmethod
436
+ def create_webhook_handler(cls, app_state: ApplicationState) -> Callable:
437
+ """Create the webhook handler function."""
438
+ country_service = CountryMappingService()
439
+
440
+ @webhook_listener(events=["response.created"])
441
+ async def handle_response_created(response, type, timestamp):
442
+ try:
443
+ # Log the event
444
+ logger.info(f"Received webhook event: {type} at {timestamp}")
445
+
446
+ # Add basic event to the queue
447
+ app_state.add_event(Event(
448
+ event_type=type,
449
+ timestamp=str(timestamp)
450
+ ))
451
+
452
+ # Extract dataset name
453
+ record = response.record
454
+ dataset_name = record.dataset.name
455
+ logger.info(f"Processing response for dataset: {dataset_name}")
456
+
457
+ # Find country code from dataset name
458
+ country_code = country_service.find_country_code_from_dataset(dataset_name)
459
+
460
+ # Update country progress if found
461
+ if country_code:
462
+ success = app_state.increment_country_progress(country_code)
463
+ if success:
464
+ country_data = app_state.countries[country_code]
465
+ logger.info(
466
+ f"Updated progress for {country_data.name}: "
467
+ f"{country_data.count}/{country_data.target} ({country_data.percent}%)"
468
+ )
469
+
470
+ except Exception as e:
471
+ logger.error(f"Error in webhook handler: {e}", exc_info=True)
472
+ app_state.add_event(Event(
473
+ event_type="error",
474
+ error=str(e)
475
+ ))
476
+
477
+ return handle_response_created
478
 
479
+ @classmethod
480
+ def create_ui(cls, argilla_service: ArgillaService, app_state: ApplicationState):
481
+ """Create the Gradio UI."""
482
+ # Create and configure the Gradio interface
483
+ demo = gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple"))
484
+
485
+ with demo:
486
+ argilla_server = argilla_service.get_client_base_url()
487
 
488
+ with gr.Row():
489
+ gr.Markdown(f"""
490
+ # Latin America & Spain Annotation Progress Map
491
+
492
+ ### Connected to Argilla server: {argilla_server}
493
+
494
+ This dashboard visualizes annotation progress across Latin America and Spain.
495
+ """)
496
+
497
+ with gr.Row():
498
+ with gr.Column(scale=2):
499
+ # Map visualization - empty at first
500
+ map_html = gr.HTML(MapVisualization.create_map_html(), label="Annotation Progress Map")
501
+
502
+ # Hidden element to store map data
503
+ map_data = gr.JSON(value=app_state.to_json(), visible=False)
504
+
505
+ with gr.Column(scale=1):
506
+ # Recent events log
507
+ events_json = gr.JSON(label="Recent Events", value={})
508
+
509
+ # Overall statistics
510
+ total_docs, avg_completion, countries_over_50 = app_state.get_stats()
511
+ total_docs_ui = gr.Number(value=total_docs, label="Total Documents", interactive=False)
512
+ avg_completion_ui = gr.Number(value=avg_completion, label="Average Completion (%)", interactive=False)
513
+ countries_over_50_ui = gr.Number(value=countries_over_50, label="Countries Over 50%", interactive=False)
514
+
515
+ # Country details
516
+ country_selector = gr.Dropdown(
517
+ choices=[f"{data.name} ({code})" for code, data in app_state.countries.items()],
518
+ label="Select Country"
519
+ )
520
+ country_progress = gr.JSON(label="Country Progress", value={})
521
+
522
+ # Refresh button
523
+ refresh_btn = gr.Button("Refresh Map")
524
+
525
+ # UI interaction functions
526
+ def update_map():
527
+ return app_state.to_json()
528
+
529
+ def update_country_details(country_selection):
530
+ if not country_selection:
531
+ return {}
532
+
533
+ # Extract the country code from the selection (format: "Country Name (CODE)")
534
+ code = country_selection.split("(")[-1].replace(")", "").strip()
535
+
536
+ if code in app_state.countries:
537
+ return asdict(app_state.countries[code])
538
+ return {}
539
+
540
+ def update_events():
541
+ event = app_state.get_next_event()
542
+ stats = app_state.get_stats()
543
+
544
+ # If this is a progress update, update the map data
545
+ if event.get("event_type") == "progress_update":
546
+ # This will indirectly trigger a map refresh through the change event
547
+ return event, app_state.to_json(), stats[0], stats[1], stats[2]
548
+
549
+ return event, None, stats[0], stats[1], stats[2]
550
+
551
+ # Set up event handlers
552
+ refresh_btn.click(
553
+ fn=update_map,
554
+ inputs=None,
555
+ outputs=map_data
556
  )
557
+
558
+ country_selector.change(
559
+ fn=update_country_details,
560
+ inputs=[country_selector],
561
+ outputs=[country_progress]
562
+ )
563
+
564
+ # When map_data is updated, reload the D3 script
565
+ map_data.change(
566
+ fn=lambda data: (None, MapVisualization.create_d3_script(data)),
567
+ inputs=map_data,
568
+ outputs=[None, None],
569
+ _js=MapVisualization.create_d3_script(app_state.to_json())
570
+ )
571
+
572
+ # Use timer to check for new events and update stats
573
+ gr.Timer(1, active=True).tick(
574
+ update_events,
575
+ outputs=[events_json, map_data, total_docs_ui, avg_completion_ui, countries_over_50_ui]
576
+ )
577
+
578
+ # Initialize D3 on page load
579
+ demo.load(None, None, None, _js=MapVisualization.create_d3_script(app_state.to_json()))
580
 
581
+ return demo
582
+
583
+ # ============================================================================
584
+ # MAIN APPLICATION - Entry point and initialization
585
+ # ============================================================================
586
+
587
+ def create_application():
588
+ """Create and configure the complete application."""
589
+ # Create application components
590
+ app_state = ApplicationFactory.create_app_state()
591
+ argilla_service = ApplicationFactory.create_argilla_service()
 
 
 
 
592
 
593
+ # Create and register webhook handler
594
+ webhook_handler = ApplicationFactory.create_webhook_handler(app_state)
 
 
 
 
 
 
 
 
 
595
 
596
+ # Create the UI
597
+ demo = ApplicationFactory.create_ui(argilla_service, app_state)
 
 
 
 
598
 
599
+ # Mount the Gradio app to the FastAPI server
600
+ server = argilla_service.get_server()
601
+ gr.mount_gradio_app(server, demo, path="/")
 
 
 
 
 
 
 
 
 
 
602
 
603
+ return server
 
604
 
605
+ # Application entry point
 
 
 
 
 
 
 
 
 
606
  if __name__ == "__main__":
607
  import uvicorn
608
 
609
+ # Create the application
610
+ server = create_application()
 
 
 
 
 
 
 
 
611
 
612
  # Start the server
613
  uvicorn.run(server, host="0.0.0.0", port=7860)