ouhenio's picture
Update app.py
9ba222b verified
raw
history blame
6.42 kB
import gradio as gr
import random
import json
import fastapi
from fastapi import FastAPI, Request
import os
import argilla as rg
from functools import lru_cache
import time
import asyncio
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
# Argilla client configured from the ARGILLA_API_URL / ARGILLA_API_KEY
# environment variables; a variable that is not set is passed as "".
client = rg.Argilla(
    api_key=os.getenv("ARGILLA_API_KEY", ""),
    api_url=os.getenv("ARGILLA_API_URL", ""),
)
# Countries shown on the map: display name -> three-letter ISO code and flag
# emoji. Insertion order is kept, and it is the order in which the map data
# is built.
countries = {
    name: {"iso": iso, "emoji": emoji}
    for name, iso, emoji in (
        ("Argentina", "ARG", "🇦🇷"),
        ("Bolivia", "BOL", "🇧🇴"),
        ("Chile", "CHL", "🇨🇱"),
        ("Colombia", "COL", "🇨🇴"),
        ("Costa Rica", "CRI", "🇨🇷"),
        ("Cuba", "CUB", "🇨🇺"),
        ("Ecuador", "ECU", "🇪🇨"),
        ("El Salvador", "SLV", "🇸🇻"),
        ("España", "ESP", "🇪🇸"),
        ("Guatemala", "GTM", "🇬🇹"),
        ("Honduras", "HND", "🇭🇳"),
        ("México", "MEX", "🇲🇽"),
        ("Nicaragua", "NIC", "🇳🇮"),
        ("Panamá", "PAN", "🇵🇦"),
        ("Paraguay", "PRY", "🇵🇾"),
        ("Perú", "PER", "🇵🇪"),
        ("Puerto Rico", "PRI", "🇵🇷"),
        ("República Dominicana", "DOM", "🇩🇴"),
        ("Uruguay", "URY", "🇺🇾"),
        ("Venezuela", "VEN", "🇻🇪"),
    )
}
# Memoise per-country statistics to avoid repeated Argilla queries.
# lru_cache has no time-based expiry: an entry stays until it is evicted
# (maxsize=32) and is reused only when the same (country, cache_buster) pair
# is requested again, so freshness depends entirely on the cache_buster value
# the caller passes.
@lru_cache(maxsize=32)
def count_answers_per_space_cached(country: str, cache_buster: int):
    """Return count_answers_per_space(country), memoised on (country, cache_buster).

    cache_buster is not used in the computation itself; it only becomes part
    of the cache key, so passing a new value forces a fresh query. Repeated
    cache hits return the same dict object, so callers should not mutate it.
    """
    stats = count_answers_per_space(country)
    return stats
def count_answers_per_space(country: str):
    """Summarise annotation progress for one country's Argilla dataset.

    The dataset is looked up by the name
    ``"{emoji} {country} - {iso} - Responder"`` built from ``countries``.

    Args:
        country: A key of the module-level ``countries`` mapping.

    Returns:
        dict with keys:
          - ``name``: the country name passed in.
          - ``total_questions``: number of records in the dataset.
          - ``answered_questions``: records with at least one "text" response.
          - ``total_answers``: total number of "text" responses.
          - ``percent``: answered / total * 100, rounded to 2 decimals
            (0 when the dataset has no records).
          - ``documents``: ``total_answers * 10``.
        If fetching or reading the dataset raises, every numeric field is 0.

    Raises:
        KeyError: if ``country`` is not a key of ``countries``. This lookup
            happens before the fallback handler, so it is not swallowed.
    """
    iso = countries[country]["iso"]
    emoji = countries[country]["emoji"]
    dataset_name = f"{emoji} {country} - {iso} - Responder"
    try:
        dataset = client.datasets(dataset_name)
        total_questions = 0
        answered_questions = 0
        total_answers = 0
        # Only running totals are needed, so iterate the records lazily
        # instead of materialising them all in a list first.
        for record in dataset.records(with_responses=True):
            total_questions += 1
            responses = record.responses.get("text", [])
            if responses:
                answered_questions += 1
                total_answers += len(responses)
    except Exception as e:
        # Best effort: a missing dataset (or any other Argilla error) is shown
        # as zero progress instead of breaking the whole map.
        print(f"No dataset found for {dataset_name}: {e}")
        return {
            "name": country,
            "total_questions": 0,
            "answered_questions": 0,
            "total_answers": 0,
            "percent": 0,
            "documents": 0,
        }

    percentage_complete = (
        answered_questions / total_questions * 100 if total_questions > 0 else 0
    )
    return {
        "name": country,
        "total_questions": total_questions,
        "answered_questions": answered_questions,
        "total_answers": total_answers,
        "percent": round(percentage_complete, 2),
        # NOTE(review): fixed factor of 10 documents per answer; confirm its origin.
        "documents": total_answers * 10,
    }
# FastAPI application: serves the /d3-map route and has the Gradio UI mounted on it.
app = FastAPI()
# Gzip responses of at least 1000 bytes to shrink the transferred HTML.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Module-level page cache used by serve_map: the last rendered HTML (None until
# the first build) and the time it was built.
cached_html_content = None
last_update_time = time.time()
# Route to serve the map visualization
@app.get("/d3-map")
async def serve_map(request: Request, refresh: bool = False):
    """Render the D3 map page filled with per-country annotation statistics.

    The rendered page is kept in the module-level ``cached_html_content`` and
    served as-is for ``cache_ttl`` seconds after ``last_update_time``, unless
    ``refresh`` is true.

    Args:
        request: The incoming request (not used).
        refresh: When true, skip both the page cache and the per-country
            ``lru_cache`` and query Argilla again.

    Returns:
        HTMLResponse: ``template.txt`` with ``COUNTRY_DATA_PLACEHOLDER``
        replaced by a JSON object mapping each ISO code to its statistics.
    """
    global last_update_time, cached_html_content
    cache_ttl = 300  # seconds
    current_time = time.time()
    if cached_html_content and current_time - last_update_time < cache_ttl and not refresh:
        return HTMLResponse(content=cached_html_content)

    # cache_buster is part of the lru_cache key. Bucketing it by TTL window
    # lets cached per-country results be reused within a window. A forced
    # refresh uses a per-second key that bypasses them. (Using a per-second
    # key on every rebuild would mean the lru_cache never hits.)
    if refresh:
        cache_buster = int(current_time)
    else:
        cache_buster = int(current_time // cache_ttl)

    # The Argilla queries are blocking. Run them in the default thread pool so
    # the event loop is not stalled and the countries are fetched concurrently.
    loop = asyncio.get_running_loop()
    names = list(countries)
    stats = await asyncio.gather(
        *(
            loop.run_in_executor(None, count_answers_per_space_cached, name, cache_buster)
            for name in names
        )
    )
    country_data = {countries[name]["iso"]: stat for name, stat in zip(names, stats)}
    country_data_json = json.dumps(country_data)

    # Explicit encoding so the template is read the same way regardless of locale.
    with open("template.txt", "r", encoding="utf-8") as f:
        html_template = f.read()
    html_content = html_template.replace("COUNTRY_DATA_PLACEHOLDER", country_data_json)

    cached_html_content = html_content
    last_update_time = current_time
    return HTMLResponse(content=html_content)
# Build the <iframe> that embeds the /d3-map page inside the Gradio UI.
def create_iframe(refresh=False):
    """Return an <iframe> tag whose src points at /d3-map.

    The query string carries ``refresh`` (``str(refresh).lower()``) and a
    random ``t`` in [1, 10000]. The random value makes each returned URL
    differ, so the browser loads the frame again.
    """
    query = "&".join((f"refresh={str(refresh).lower()}", f"t={random.randint(1, 10000)}"))
    frame_style = "width:100%; height:650px; border:none;"
    return f'<iframe src="/d3-map?{query}" style="{frame_style}"></iframe>'
# Gradio UI: a title, the embedded map iframe, and a button that reloads it.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple")) as demo:
    gr.Markdown("# Mapa anotación")
    iframe_output = gr.HTML(create_iframe())

    def refresh():
        """Return a new iframe whose URL asks /d3-map to rebuild its data."""
        return create_iframe(refresh=True)

    refresh_button = gr.Button("Actualizar Datos")
    refresh_button.click(fn=refresh, outputs=iframe_output)

# Serve the Gradio UI from the root path of the FastAPI app.
gr.mount_gradio_app(app, demo, path="/")
# Entry point when this file is executed directly.
if __name__ == "__main__":
    from uvicorn import run as run_server

    # Listen on all interfaces, port 7860.
    run_server(app, host="0.0.0.0", port=7860)