import urllib.parse
from os import PathLike
from pathlib import Path

from aiohttp import web
from aiohttp.web_urldispatcher import AbstractRoute, UrlDispatcher

from server import PromptServer

# Maximum file size (in MB) that a size-limited static route will serve.
# Read at request time by filesize_limiter, so it can be changed at runtime.
max_size = 50

# Lower-case file extensions (including the leading dot) that a
# suffix-limited static route is allowed to serve.
ALLOWED_SUFFIXES = frozenset(
    {
        ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff",
        ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr",
    }
)


def _resolve_requested_file(resource: web.StaticResource, request: web.Request) -> Path:
    """Return the resolved filesystem path a request asks *resource* to serve.

    The ``filename`` match-info segment is joined onto the resource's
    directory and resolved (following symlinks), the same way the limiters
    previously did inline.

    Raises:
        web.HTTPForbidden: if the requested name is absolute (has an anchor),
            which would otherwise make ``joinpath`` discard the directory.
        OSError, ValueError: propagated from path resolution; the caller
            (``LimitResource._handle``) turns them into 404 responses.
    """
    filename = Path(request.match_info["filename"])
    if filename.anchor:
        raise web.HTTPForbidden()
    return resource._directory.joinpath(filename).resolve()


def suffix_limiter(self: web.StaticResource, request: web.Request) -> None:
    """Refuse existing files whose extension is not in ``ALLOWED_SUFFIXES``.

    Non-existent paths are let through so the base static handler can answer
    them. An existing path with no suffix (e.g. a directory) is refused too.

    Raises:
        web.HTTPForbidden: for absolute names or disallowed file types.
    """
    filepath = _resolve_requested_file(self, request)
    if filepath.exists() and filepath.suffix.lower() not in ALLOWED_SUFFIXES:
        raise web.HTTPForbidden(reason="File type is not allowed")


def filesize_limiter(self: web.StaticResource, request: web.Request) -> None:
    """Refuse existing files larger than ``max_size`` megabytes.

    Raises:
        web.HTTPForbidden: for absolute names or files that are too large.
    """
    filepath = _resolve_requested_file(self, request)
    if filepath.exists() and filepath.stat().st_size > max_size * 1024 * 1024:
        raise web.HTTPForbidden(reason="File size is too large")


class LimitResource(web.StaticResource):
    """Static resource that runs limiter callbacks before serving a file.

    Each limiter is called as ``limiter(resource, request)`` and may raise an
    aiohttp HTTP exception (e.g. ``HTTPForbidden``) to refuse the request.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Per-instance list. A class-level list would be shared by every
        # LimitResource: each new route would re-append its limiters, and
        # every route would run all of them repeatedly.
        self.limiters = []
        super().__init__(*args, **kwargs)

    def push_limiter(self, limiter) -> None:
        """Append *limiter* to the checks run before each request is served."""
        self.limiters.append(limiter)

    async def _handle(self, request: web.Request) -> web.StreamResponse:
        try:
            for limiter in self.limiters:
                limiter(self, request)
        except (ValueError, OSError) as error:
            # Unresolvable / inaccessible paths are reported as "not found"
            # rather than surfacing as a 500 error.
            raise web.HTTPNotFound() from error
        return await super()._handle(request)

    def __repr__(self) -> str:
        name = "'" + self.name + "'" if self.name is not None else ""
        return f"<LimitResource {name} {self._prefix} -> {self._directory!r}>"


def _add_limited_static(
    router: UrlDispatcher,
    prefix: str,
    path: PathLike,
    *,
    name=None,
    expect_handler=None,
    chunk_size: int = 256 * 1024,
    show_index: bool = False,
    follow_symlinks: bool = False,
    append_version: bool = False,
) -> web.AbstractResource:
    """Register a LimitResource on *router*, like ``UrlDispatcher.add_static``.

    The resource gets ``suffix_limiter`` and ``filesize_limiter`` attached.
    """
    # Mirrors aiohttp's add_static precondition.
    assert prefix.startswith("/")
    if prefix.endswith("/"):
        prefix = prefix[:-1]
    resource = LimitResource(
        prefix,
        path,
        name=name,
        expect_handler=expect_handler,
        chunk_size=chunk_size,
        show_index=show_index,
        follow_symlinks=follow_symlinks,
        append_version=append_version,
    )
    resource.push_limiter(suffix_limiter)
    resource.push_limiter(filesize_limiter)
    router.register_resource(resource)
    return resource


class LimitRouter(web.StaticDef):
    """Route definition that registers a size- and type-limited static resource."""

    def __repr__(self) -> str:
        info = "".join(
            f", {name}={value!r}" for name, value in sorted(self.kwargs.items())
        )
        return f"<LimitRouter {self.prefix} -> {self.path}{info}>"

    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        resource = _add_limited_static(router, self.prefix, self.path, **self.kwargs)
        routes = resource.get_info().get("routes", {})
        return list(routes.values())


def path_to_url(path):
    """Normalize *path* into a URL path.

    Backslashes become slashes, a single leading slash is ensured, and every
    run of consecutive slashes is collapsed to one. Falsy input is returned
    unchanged.
    """
    if not path:
        return path
    path = path.replace("\\", "/")
    if not path.startswith("/"):
        path = "/" + path
    # A single replace() leaves "//" behind for runs of 3+ slashes
    # ("///" -> "//"), so repeat until none remain.
    while "//" in path:
        path = path.replace("//", "/")
    return path


def add_static_resource(prefix, path, limit=False):
    """Serve directory *path* under URL *prefix* on the PromptServer app.

    When *limit* is true, only allowed image types no larger than ``max_size``
    MB are served. Symlinks are followed in both modes.
    """
    app = PromptServer.instance.app
    prefix = path_to_url(prefix)
    prefix = urllib.parse.quote(prefix)
    prefix = path_to_url(prefix)
    if limit:
        route = LimitRouter(prefix, path, {"follow_symlinks": True})
    else:
        route = web.static(prefix, path, follow_symlinks=True)
    app.add_routes([route])