Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import re | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Annotated, List | |
| from cashews import NOT_NONE, cache | |
| from dotenv import load_dotenv | |
| from fastapi import BackgroundTasks, FastAPI, Header, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from httpx import AsyncClient | |
| from huggingface_hub import CommitScheduler, DatasetCard, HfApi, hf_hub_download, whoami | |
| from huggingface_hub.utils import disable_progress_bars, logging | |
| from huggingface_hub.utils._errors import HTTPError | |
| from langfuse.openai import AsyncOpenAI # OpenAI integration | |
| from pydantic import BaseModel, Field | |
| from starlette.responses import RedirectResponse | |
| from card_processing import parse_markdown, try_load_text, is_empty_template | |
disable_progress_bars()
load_dotenv()

# NOTE: `logging` here is huggingface_hub.utils.logging (imported after, and
# therefore shadowing, the stdlib module); it provides get_logger.
logger = logging.get_logger(__name__)

Gb = 1024**3  # bytes per GiB (1073741824)
# Disk-backed cache (NOT in-memory), capped at 16 GiB. get_summary stores
# (summary, last-modified-time) pairs in it.
cache.setup("disk://", size_limit=16 * Gb)

# Local JSONL file that votes are appended to; it lives inside the "data"
# folder that `scheduler` syncs to the Hub.
VOTES_FILE = "data/votes.jsonl"
HF_TOKEN = os.getenv("HF_TOKEN")
hf_api = HfApi(token=HF_TOKEN)
# Shared HTTP client reused for Hub requests instead of opening one per call.
async_httpx_client = AsyncClient()

# Periodically pushes the local "data" folder (which contains VOTES_FILE) to
# data/ in the davanstrien/summary-ratings dataset repo; `every` is in minutes.
scheduler = CommitScheduler(
    repo_id="davanstrien/summary-ratings",
    repo_type="dataset",
    folder_path="data",
    path_in_repo="data",
    every=5,
    token=HF_TOKEN,
    hf_api=hf_api,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: seed the local votes file on startup, clean up on shutdown.

    On startup, if ``VOTES_FILE`` does not exist locally it is downloaded from
    the ``davanstrien/summary-ratings`` dataset repo. Presumably this is so that
    votes appended locally extend the existing history rather than replacing it
    on the next scheduled commit. On shutdown the shared ``async_httpx_client``
    is closed.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    logger.info("Running startup event")
    if not Path(VOTES_FILE).exists():
        # local_dir="." + filename "data/votes.jsonl" lands exactly at VOTES_FILE.
        path = hf_hub_download(
            repo_id="davanstrien/summary-ratings",
            filename=VOTES_FILE,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=".",
            local_dir_use_symlinks=False,
        )
        logger.info(f"Downloaded votes.jsonl to {path}")
    else:
        logger.info("Votes file already exists")
    try:
        yield
    finally:
        # Release the pooled connections held by the module-level HTTP client.
        await async_httpx_client.aclose()
app = FastAPI(lifespan=lifespan)

# CORS is currently disabled. Browser clients on other origins may need to call
# this API, e.g. huggingface.co pages or a Chrome extension such as
# chrome-extension://deckahggoiaphiebdipfbiinmaihfpbk. To allow that, register
# CORSMiddleware.
#
# Note: Starlette's ``allow_origins`` only matches exact origin strings
# (scheme://host[:port], never a path). An entry like "https://huggingface.co/*"
# therefore matches nothing. Use ``allow_origin_regex`` for patterns:
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origin_regex=r"https://([a-z0-9-]+\.)*huggingface\.co",
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
def save_vote(vote_entry):
    """Stamp a vote record with the current local time and append it to VOTES_FILE.

    The append is done while holding ``scheduler.lock`` so it does not interleave
    with the CommitScheduler's periodic upload of the data folder.

    Args:
        vote_entry: JSON-serialisable dict describing the vote. It is modified in
            place: a ``"timestamp"`` key ("YYYY-mm-dd HH:MM:SS") is added.
    """
    with scheduler.lock, open(VOTES_FILE, "a") as votes_file:
        vote_entry["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        votes_file.write(f"{json.dumps(vote_entry)}\n")
        logger.info(f"Vote saved: {vote_entry}")
def root():
    """Redirect to the interactive OpenAPI docs page."""
    # NOTE(review): no route decorator is visible for this handler; presumably
    # it is registered as GET "/" — confirm.
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
class Vote(BaseModel):
    """Request body carrying one user's rating of a dataset summary."""

    # Hub ID of the dataset whose summary was rated.
    dataset: str
    # Presumably the summary text that was shown and rated — confirm with client.
    description: str
    # The rating; validation restricts it to an integer between -1 and 1 inclusive.
    vote: Annotated[int, Field(ge=-1, le=1)]
    # Client-supplied user identifier (not cross-checked against the auth token).
    userID: str
def validate_token(token: str = Header(None)) -> bool:
    """Check whether a Hugging Face token is accepted by the Hub ``whoami`` endpoint.

    Accepts either a bare token or an ``Authorization`` header value of the form
    ``"Bearer <token>"`` (scheme matched case-insensitively).

    Args:
        token: Hub access token or Authorization header value. ``None`` or an
            empty string is rejected without contacting the Hub.

    Returns:
        True if the Hub recognises the token. False if it is missing or the Hub
        rejects it with an HTTP error. Other failures (e.g. network errors)
        propagate to the caller.

    Note:
        ``whoami`` performs a blocking HTTP request; async callers should run
        this function in a worker thread.
    """
    if not token:
        # Never call whoami(None): huggingface_hub would fall back to the
        # server's locally configured token and report *that* as valid.
        return False
    scheme, _, credentials = token.strip().partition(" ")
    if scheme.lower() == "bearer" and credentials.strip():
        # Strip the scheme; whoami adds its own "Bearer " prefix.
        token = credentials.strip()
    try:
        whoami(token)
    except HTTPError:
        return False
    return True
async def receive_vote(
    vote: Vote,
    Authorization: Annotated[str, Header()],
    background_tasks: BackgroundTasks,
):
    """Validate the caller's Hub token and queue the vote to be persisted.

    Args:
        vote: Parsed request body.
        Authorization: Value of the ``Authorization`` request header, checked
            with ``validate_token``.
        background_tasks: FastAPI background-task queue; ``save_vote`` is
            scheduled on it so the write happens after the response is sent.

    Returns:
        A JSON response confirming the submission.

    Raises:
        HTTPException: 401 if the token is rejected.
    """
    # validate_token makes a blocking HTTP call (whoami). Run it in a worker
    # thread so it does not stall the event loop for every other request.
    if not await asyncio.to_thread(validate_token, Authorization):
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    # NOTE(review): userID comes from the request body and is not tied to the
    # identity behind the token.
    vote_entry = {
        "dataset": vote.dataset,
        "vote": vote.vote,
        "description": vote.description,
        "userID": vote.userID,
    }
    # Append the vote entry to the JSONL file
    background_tasks.add_task(save_vote, vote_entry)
    return JSONResponse(content={"message": "Vote submitted successfully"})
| def format_prompt(card: str) -> str: | |
| return f""" | |
| Write a tl;dr summary of a dataset based on the dataset card. Focus on the most critical aspects of the dataset. | |
| The summary should aim to concisely describe the dataset. | |
| CARD: \n\n{card[:6000]} | |
| --- | |
| \n\nInstructions: | |
| If the card provides the necessary information, say what the dataset can be used for. | |
| You do not need to mention that the dataset is hosted or available on the Hugging Face Hub. | |
| Do not mention the license of the dataset. | |
| Do not mention the number of examples in the training or test split. | |
| Only mention size if there is extensive discussion of the scale of the dataset in the dataset card. | |
| Do not speculate on anything not explicitly mentioned in the dataset card. | |
| In general avoid references to the quality of the dataset i.e. don't use phrases like 'a high-quality dataset' in the summary. | |
| \n\nOne sentence summary:""" | |
async def check_when_dataset_last_modified(dataset_id: str) -> datetime | None:
    """Return when a Hub dataset was last modified, or None if it can't be determined.

    Queries ``https://huggingface.co/api/datasets/{dataset_id}`` and parses the
    ``lastModified`` field of the JSON response. Any failure (network error,
    non-JSON body, missing field, unparseable timestamp) is logged and reported
    as None.

    Args:
        dataset_id: Hub ID of the dataset.

    Returns:
        The parsed modification time, or None.
    """
    try:
        response = await async_httpx_client.get(
            f"https://huggingface.co/api/datasets/{dataset_id}",
            # httpx does not follow redirects by default; renamed repos redirect.
            follow_redirects=True,
        )
        last_modified = response.json().get("lastModified")
        if not last_modified:
            return None
        # Assumed Hub format is ISO 8601 with a trailing "Z" (UTC) — verify.
        # datetime.fromisoformat only accepts "Z" from Python 3.11, so normalise it.
        if last_modified.endswith("Z"):
            last_modified = last_modified[:-1] + "+00:00"
        return datetime.fromisoformat(last_modified)
    except Exception as e:
        logger.error(e)
        return None
async def predict(card: str, dataset_id: str) -> str | None:
    """Generate a one-sentence summary of a dataset card with Mixtral-8x7B-Instruct.

    The request goes to the Hugging Face Inference API's OpenAI-compatible
    endpoint through the Langfuse-wrapped OpenAI client, tagged "tldr-summaries".

    Args:
        card: Dataset card text to summarise.
        dataset_id: Hub ID of the dataset. Not used by the current implementation;
            kept for interface compatibility.

    Returns:
        The stripped summary text. None if the request fails or the response has
        no usable content; the error is logged in that case.
    """
    try:
        prompt = format_prompt(card)
        # Use the client as an async context manager so its underlying HTTP
        # connection pool is closed after the request, instead of leaking one
        # pool per call.
        async with AsyncOpenAI(
            base_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1/v1",
            api_key=HF_TOKEN,
        ) as client:
            chat_completion = await client.chat.completions.create(
                model="tgi",
                messages=[
                    {"role": "user", "content": prompt},
                ],
                stream=False,
                tags=["tldr-summaries"],
            )
        return chat_completion.choices[0].message.content.strip()
    except Exception as e:
        logger.error(e)
        return None
async def get_summary(dataset_id: str) -> str | None:
    """
    Get a summary for a dataset based on the provided dataset ID.

    A cached summary is reused when either of these holds:
    - the dataset's current ``lastModified`` time cannot be determined, or
    - that time is not newer than the time stored with the cached summary.

    Otherwise the dataset card is fetched and a new summary is generated with
    ``predict``. Successful summaries are cached.

    Args:
        dataset_id (str): The ID of the dataset to retrieve the summary for.

    Returns:
        str | None: The generated summary for the dataset, or None when no summary
        is available or an error occurs. "No summary available" covers an
        unreadable card, an empty template card, or failed generation.
    """
    cache_key = f"predict:{dataset_id}"
    try:
        # Read the modification time *before* the card, so the timestamp stored
        # with a new summary is never newer than the card it was built from.
        current_last_modified_time = await check_when_dataset_last_modified(
            dataset_id
        )

        cached_data = await cache.get(cache_key)
        if cached_data is not None:
            cached_summary, cached_last_modified_time = cached_data
            # A cached time of None cannot be compared with a datetime (it would
            # raise TypeError), so such entries are treated as stale.
            is_fresh = current_last_modified_time is None or (
                cached_last_modified_time is not None
                and cached_last_modified_time >= current_last_modified_time
            )
            # Entries holding a None summary are regenerated rather than reused.
            if cached_summary is not None and is_fresh:
                logger.info("Using cached summary")
                return cached_summary

        response = await async_httpx_client.get(
            f"https://huggingface.co/datasets/{dataset_id}/raw/main/README.md",
            follow_redirects=True,
        )
        if not response.is_success:
            # Non-2xx bodies (e.g. missing, private or gated repos) are error
            # pages, not dataset cards; don't summarise them.
            logger.info(
                f"No readable dataset card for {dataset_id} "
                f"(HTTP {response.status_code})"
            )
            return None
        card = DatasetCard(response.text)
        parsed_text = parse_markdown(card.text)
        if is_empty_template(parsed_text):
            return None

        summary = await predict(parsed_text, dataset_id)
        if summary is not None:
            # Only cache successes so a transient inference failure is retried.
            await cache.set(cache_key, (summary, current_last_modified_time))
        return summary
    except Exception as e:
        logger.error(e)
        return None
class SummariesRequest(BaseModel):
    """Request body listing the Hub dataset IDs to summarise."""

    # Dataset IDs to look up, in the order the client sent them.
    dataset_ids: list[str]
async def get_summaries(request: SummariesRequest) -> dict:
    """
    Get summaries for a list of datasets based on the provided dataset IDs.

    Summaries are generated concurrently, and repeated IDs are summarised only
    once. Datasets for which no summary could be produced map to
    "No summary available".

    Args:
        request (SummariesRequest): Request body whose ``dataset_ids`` lists the
            dataset IDs to retrieve summaries for.

    Returns:
        dict: A dictionary mapping each requested dataset ID (in first-seen
        order) to its summary.
    """
    # dict.fromkeys drops duplicate IDs while keeping first-seen order, so a
    # repeated ID does not trigger redundant concurrent LLM calls.
    unique_ids = list(dict.fromkeys(request.dataset_ids))
    # gather returns results in the same order as its arguments.
    results = await asyncio.gather(
        *(get_summary(dataset_id) for dataset_id in unique_ids)
    )
    return {
        dataset_id: "No summary available" if summary is None else summary
        for dataset_id, summary in zip(unique_ids, results)
    }