Spaces:
Sleeping
Sleeping
import contextlib | |
from typing import Literal, Tuple, List | |
import httpx | |
import nbformat | |
from nbformat import NotebookNode, ValidationError | |
from nbconvert import HTMLExporter | |
from starlette.applications import Starlette | |
from starlette.exceptions import HTTPException | |
from starlette.responses import FileResponse, JSONResponse, HTMLResponse | |
from starlette.requests import Request | |
from starlette.routing import Route | |
from nbconvert.preprocessors import Preprocessor | |
import re | |
from traitlets.config import Config | |
from huggingface_hub import model_info, dataset_info | |
from huggingface_hub.utils import RepositoryNotFoundError | |
from functools import lru_cache | |
# Matches a Hub repo ID of the form ``owner/name``; the ID is capture group 1.
# Lookarounds are used instead of consumed delimiter characters so that an ID
# sitting at the very start or very end of the searched text is still found.
# The trailing lookahead also rejects a following "/" so that longer paths
# such as ``a/b/c`` are not mistaken for repo IDs.
hub_id_regex = re.compile(r"(?<!\w)([a-zA-Z\d-]{3,32}/[\w\-._]{3,64})(?![\w/])")
# TODO possibly make async but might be tricky to call inside PreProcessor
def check_hub_item(hub_id_match):
    """Classify a candidate Hub ID as a model or a dataset.

    The ID is looked up with ``model_info`` first and ``dataset_info`` second.
    A ``RepositoryNotFoundError`` from a lookup means "not of this kind"; any
    other exception propagates to the caller.

    Returns:
        ``(hub_id_match, "model")`` or ``(hub_id_match, "dataset")`` for the
        first lookup that succeeds, or ``None`` when both lookups report the
        repository as not found.
    """
    lookups = ((model_info, "model"), (dataset_info, "dataset"))
    for lookup, repo_kind in lookups:
        try:
            lookup(hub_id_match)
        except RepositoryNotFoundError:
            continue
        return hub_id_match, repo_kind
    return None
# async def check_repo_exists(regex_hub_id_match: str) -> Optional[Tuple[str, str]]: | |
# r = await client.get(f"https://huggingface.co/api/models/{regex_hub_id_match}") | |
# if r.status_code == 200: | |
# return regex_hub_id_match, 'model' | |
# r = await client.get(f"https://huggingface.co/api/datasets/{regex_hub_id_match}") | |
# if r.status_code == 200: | |
# return regex_hub_id_match, 'dataset' | |
class HubIDCell(Preprocessor):
    """Preprocessor that records Hub repo IDs referenced in code cells.

    Results are accumulated across cells in two sets stored on ``resources``:
    ``resources["model_matches"]`` and ``resources["dataset_matches"]``.
    """

    def preprocess_cell(self, cell, resources, index):
        """Record the first Hub ID found in a code cell.

        Only the first regex match in the cell source is considered. An ID is
        recorded only if ``check_hub_item`` confirms it; IDs that are already
        recorded are not looked up again.

        Args:
            cell: The notebook cell being processed; left unmodified.
            resources: Shared resources mapping that receives the match sets.
            index: Position of the cell in the notebook (unused).

        Returns:
            The ``(cell, resources)`` pair.
        """
        # Create the result sets for every cell type so that readers of
        # ``resources`` find them even when the notebook has no code cells.
        resources.setdefault("dataset_matches", set())
        resources.setdefault("model_matches", set())
        if cell["cell_type"] != "code":
            return cell, resources

        match = hub_id_regex.search(cell["source"])
        if match is None:
            return cell, resources
        hub_id_match = match.group(1)

        # Skip the (network) lookup when this ID has already been classified.
        already_seen = (
            hub_id_match in resources["model_matches"]
            or hub_id_match in resources["dataset_matches"]
        )
        if already_seen:
            return cell, resources

        hub_check = check_hub_item(hub_id_match)
        if hub_check is not None:
            hub_id_match, repo_item_type = hub_check
            if repo_item_type == "model":
                resources["model_matches"].add(hub_id_match)
            elif repo_item_type == "dataset":
                resources["dataset_matches"].add(hub_id_match)
        return cell, resources
# Exporter configuration: register HubIDCell in HTMLExporter's preprocessor list.
c = Config()
c.HTMLExporter.preprocessors = [HubIDCell]
# Module-level HTTP client; convert_from_url uses it to fetch remote notebooks.
client = httpx.AsyncClient()
# Module-level exporter; convert() assigns its theme before each conversion.
html_exporter = HTMLExporter(config=c)
async def homepage(_):
    """Serve the static index page; the request object is ignored."""
    index_page = "static/index.html"
    return FileResponse(index_page)
async def healthz(_):
    """Health-check endpoint: always answers with ``{"success": true}``."""
    payload = {"success": True}
    return JSONResponse(payload)
def _linkify_hub_ids(body, model_ids, dataset_ids):
    """Wrap every occurrence of a known Hub ID in ``body`` with a link to it."""
    # Datasets are served under /datasets/ on the Hub; models live at the root.
    # Models are applied last so an ID present in both collections links as a
    # model, mirroring the lookup order of check_hub_item.
    url_prefix = {hub_id: "datasets/" for hub_id in dataset_ids}
    url_prefix.update({hub_id: "" for hub_id in model_ids})
    if not url_prefix:
        return body

    # One pass with the longest IDs first: an ID that is a substring of
    # another (or of an already inserted link) can no longer corrupt it, which
    # is what chained ``str.replace`` calls would do.
    alternatives = sorted(url_prefix, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(hub_id) for hub_id in alternatives))

    def _link(match):
        hub_id = match.group(0)
        return f"""<a href="https://huggingface.co/{url_prefix[hub_id]}{hub_id}">{hub_id} </a>"""

    return pattern.sub(_link, body)


def convert(
    s: str, theme: Literal["light", "dark"], debug_info: str
) -> Tuple[str, List[str], List[str]]:
    """Render a notebook to HTML and link the Hub IDs found by the exporter.

    Args:
        s: Notebook file contents as text.
        theme: Theme assigned to the shared HTML exporter before rendering.
        debug_info: Extra context included in printed logs only; it is never
            part of the error returned to the client.

    Returns:
        ``(html, model_ids, dataset_ids)`` where both ID collections are
        sorted lists of the IDs recorded in the exporter's resources.

    Raises:
        HTTPException: 400 when ``s`` is not JSON or fails nbformat validation.
    """
    # Capture potential validation error:
    try:
        notebook_node: NotebookNode = nbformat.reads(
            s,
            as_version=nbformat.current_nbformat,
        )
    except nbformat.reader.NotJSONError:
        print(400, f"Notebook is not JSON. {debug_info}")
        raise HTTPException(400, "Notebook is not JSON.")
    except ValidationError as e:
        print(
            400,
            f"Notebook is invalid according to nbformat: {e}. {debug_info}",
        )
        raise HTTPException(
            400,
            f"Notebook is invalid according to nbformat: {e}.",
        )
    print(f"Input: nbformat v{notebook_node.nbformat}.{notebook_node.nbformat_minor}")
    html_exporter.theme = theme
    body, metadata = html_exporter.from_notebook_node(notebook_node)
    metadata = dict(metadata)
    # The match sets may be absent (e.g. nothing was preprocessed); treat that
    # as "no matches" instead of failing with KeyError.
    model_matches = sorted(metadata.get("model_matches", ()))
    dataset_matches = sorted(metadata.get("dataset_matches", ()))
    # TODO(customize or simplify template?)
    # TODO(also check source code for jupyter/nbviewer)
    for model_match in model_matches:
        print(f"updating {model_match}")
    body = _linkify_hub_ids(body, model_matches, dataset_matches)
    return body, model_matches, dataset_matches
async def convert_from_url(req: Request):
    """Fetch the notebook at the ``url`` query parameter and convert it.

    Query parameters:
        url: Location of the notebook to download (required).
        theme: ``"dark"`` selects the dark theme; anything else means light.

    Returns:
        A JSON response with the keys ``html``, ``model_matches`` and
        ``dataset_matches``.

    Raises:
        HTTPException: 400 when ``url`` is missing, the download fails or
            answers with a non-200 status, or the notebook is rejected by
            ``convert``.
    """
    url = req.query_params.get("url")
    theme = "dark" if req.query_params.get("theme") == "dark" else "light"
    if not url:
        raise HTTPException(400, "Param url is missing")
    print("\n===", url)
    # NOTE(security): ``url`` is caller-controlled and is fetched server-side
    # without any host restriction — review for SSRF exposure.
    try:
        r = await client.get(
            url,
            follow_redirects=True,
            # httpx does not follow redirects by default
        )
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # Unreachable hosts, timeouts, unsupported schemes and malformed URLs
        # are problems with the request, not with this server.
        print(400, f"Could not fetch remote file: {e!r}. url={url}")
        raise HTTPException(400, "Could not fetch remote file") from e
    if r.status_code != 200:
        raise HTTPException(
            400, f"Got an error {r.status_code} when fetching remote file"
        )
    html_text, model_matches, dataset_matches = convert(
        r.text, theme=theme, debug_info=f"url={url}"
    )
    return JSONResponse(
        content={
            "html": html_text,
            "model_matches": list(model_matches),
            "dataset_matches": list(dataset_matches),
        }
    )
async def convert_from_upload(req: Request):
    """Convert a notebook sent as the raw request body and return its HTML.

    Query parameters:
        theme: ``"dark"`` selects the dark theme; anything else means light.

    Returns:
        An HTML response containing the rendered notebook.

    Raises:
        HTTPException: 400 when the body is not valid UTF-8 or the notebook is
            rejected by ``convert``.
    """
    theme = "dark" if req.query_params.get("theme") == "dark" else "light"
    raw_body = await req.body()
    try:
        s = raw_body.decode("utf-8")
    except UnicodeDecodeError:
        raise HTTPException(400, "Notebook is not valid UTF-8.") from None
    # convert() returns (html, model_matches, dataset_matches); only the HTML
    # string is valid content for an HTMLResponse.
    html_text, _model_matches, _dataset_matches = convert(
        s, theme=theme, debug_info=f"upload_from={req.headers.get('user-agent')}"
    )
    return HTMLResponse(content=html_text)
# Starlette application and its route table.
app = Starlette(
    debug=False,
    routes=[
        Route("/", homepage),
        Route("/healthz", healthz),
        # Converts the notebook located at the ``url`` query parameter.
        Route("/convert", convert_from_url),
        # Converts a notebook sent as the request body.
        Route("/upload", convert_from_upload, methods=["POST"]),
    ],
)