File size: 4,091 Bytes
8866644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import urllib.parse
from os import PathLike
from aiohttp import web
from aiohttp.web_urldispatcher import AbstractRoute, UrlDispatcher
from server import PromptServer
from pathlib import Path

# File size limit, in megabytes, for limited static routes: filesize_limiter
# rejects files larger than max_size * 1024 * 1024 bytes.
max_size: int = 50
# Lower-case file extensions (leading dot included) that a limited static
# route may serve. Built once at import time rather than on every request.
ALLOWED_SUFFIXES = frozenset({
    ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg",
    ".ico", ".apng", ".tif", ".hdr", ".exr",
})


def suffix_limiter(self: web.StaticResource, request: web.Request):
    """Refuse requests for files whose extension is not in ALLOWED_SUFFIXES.

    Args:
        self: The static resource being served; the requested filename is
            resolved against its ``_directory``.
        request: The incoming request; ``match_info["filename"]`` holds the
            path relative to the route prefix.

    Raises:
        web.HTTPForbidden: If the requested filename is absolute (has an
            anchor), or if it resolves (following symlinks) to an existing
            path whose lower-cased suffix is not allowed.

    Non-existent paths are not checked here. Exceptions raised by
    ``Path.resolve()`` / ``Path.exists()`` propagate unchanged.
    """
    filename = Path(request.match_info["filename"])
    # An absolute path would make joinpath() discard the served directory.
    if filename.anchor:
        raise web.HTTPForbidden()
    filepath = self._directory.joinpath(filename).resolve()
    if filepath.exists() and filepath.suffix.lower() not in ALLOWED_SUFFIXES:
        raise web.HTTPForbidden(reason="File type is not allowed")

def filesize_limiter(self: web.StaticResource, request: web.Request):
    """Refuse requests for files larger than ``max_size`` megabytes.

    Args:
        self: The static resource being served; the requested filename is
            resolved against its ``_directory``.
        request: The incoming request; ``match_info["filename"]`` holds the
            path relative to the route prefix.

    Raises:
        web.HTTPForbidden: If the requested filename is absolute (has an
            anchor), or if the resolved file is larger than
            ``max_size * 1024 * 1024`` bytes.

    If the file's size cannot be read (missing, not a directory component,
    unreadable, ...), nothing is raised and the request is left to whatever
    handler runs next.
    """
    filename = Path(request.match_info["filename"])
    # Same guard as suffix_limiter: an absolute path would make joinpath()
    # discard the served directory and stat an arbitrary file on disk.
    if filename.anchor:
        raise web.HTTPForbidden()
    filepath = self._directory.joinpath(filename).resolve()
    # One stat() call instead of exists() + stat(): no window for the file
    # to disappear between the check and the read.
    try:
        size = filepath.stat().st_size
    except OSError:
        return
    if size > max_size * 1024 * 1024:
        raise web.HTTPForbidden(reason="File size is too large")
class LimitResource(web.StaticResource):
    """Static-file resource that runs a chain of limiter callbacks before serving.

    A limiter is any callable invoked as ``limiter(resource, request)``. It
    vetoes the request by raising. Returning normally lets the next limiter,
    and finally ``web.StaticResource._handle``, run.
    """

    def __init__(self, *args, **kwargs):
        """Accept exactly the same arguments as ``web.StaticResource``."""
        super().__init__(*args, **kwargs)
        # Per-instance list. A class-level ``limiters = []`` would be one list
        # shared by every LimitResource, so push_limiter() on one resource
        # would add the limiter to all of them and the chain would grow with
        # every registered route.
        self.limiters = []

    def push_limiter(self, limiter):
        """Append *limiter* to this resource's chain (run in insertion order)."""
        self.limiters.append(limiter)

    async def _handle(self, request: web.Request) -> web.StreamResponse:
        """Run every limiter, then defer to the standard static-file handler.

        Raises:
            web.HTTPNotFound: If a limiter raises ValueError or
                FileNotFoundError.

        Any other exception raised by a limiter (e.g. ``web.HTTPForbidden``)
        propagates unchanged.
        """
        try:
            for limiter in self.limiters:
                limiter(self, request)
        except (ValueError, FileNotFoundError) as error:
            # Path-handling failures inside a limiter are reported as 404
            # rather than surfacing as a server error.
            raise web.HTTPNotFound() from error

        return await super()._handle(request)

    def __repr__(self) -> str:
        name = "'" + self.name + "'" if self.name is not None else ""
        return f'<LimitResource {name} {self._prefix} -> {self._directory!r}>'

class LimitRouter(web.StaticDef):
    """Static route definition that registers a LimitResource.

    It is used like ``web.static(...)``, but the registered resource is a
    LimitResource guarded by suffix_limiter and then filesize_limiter.
    """

    def __repr__(self) -> str:
        extras = "".join(
            f", {key}={val!r}" for key, val in sorted(self.kwargs.items())
        )
        return f'<LimitRouter {self.prefix} -> {self.path}{extras}>'

    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        """Add a limited static resource to *router* and return its routes."""

        def _add_limited_static(
            dispatcher: UrlDispatcher,
            prefix: str,
            path: PathLike,
            *,
            name=None,
            expect_handler=None,
            chunk_size: int = 256 * 1024,
            show_index: bool = False,
            follow_symlinks: bool = False,
            append_version: bool = False,
        ) -> web.AbstractResource:
            # Same steps as UrlDispatcher.add_static, except the resource
            # class is LimitResource and the two limiters are attached.
            assert prefix.startswith("/")
            prefix = prefix.removesuffix("/")
            limited = LimitResource(
                prefix,
                path,
                name=name,
                expect_handler=expect_handler,
                chunk_size=chunk_size,
                show_index=show_index,
                follow_symlinks=follow_symlinks,
                append_version=append_version,
            )
            for limiter in (suffix_limiter, filesize_limiter):
                limited.push_limiter(limiter)
            dispatcher.register_resource(limited)
            return limited

        resource = _add_limited_static(router, self.prefix, self.path, **self.kwargs)
        return [*resource.get_info().get("routes", {}).values()]

def path_to_url(path):
    """Normalise a filesystem-style path into an absolute URL path.

    Backslashes become forward slashes, a leading "/" is ensured, and every
    run of consecutive slashes is collapsed into one. A trailing slash is
    kept. Falsy input (e.g. ``""`` or ``None``) is returned unchanged.

    >>> path_to_url("a///b")
    '/a/b'
    >>> path_to_url("//x/")
    '/x/'
    """
    if not path:
        return path
    path = path.replace("\\", "/")
    if not path.startswith("/"):
        path = "/" + path
    # str.replace is non-overlapping, so a single pass only halves a run
    # ("///" -> "//"). Repeat until no doubled slash remains.
    while "//" in path:
        path = path.replace("//", "/")
    return path

def add_static_resource(prefix, path, limit=False):
    """Serve directory *path* from the PromptServer app under URL *prefix*.

    The prefix goes through three steps: it is normalised with path_to_url,
    percent-encoded with ``urllib.parse.quote``, and normalised again. The
    route always follows symlinks.

    Args:
        prefix: URL prefix; backslashes and missing/duplicate slashes are
            tolerated.
        path: Directory to serve.
        limit: When true, register a LimitRouter instead of a plain
            ``web.static`` route.
    """
    app = PromptServer.instance.app
    url_prefix = path_to_url(urllib.parse.quote(path_to_url(prefix)))
    route_options = {"follow_symlinks": True}
    route = (
        LimitRouter(url_prefix, path, route_options)
        if limit
        else web.static(url_prefix, path, **route_options)
    )
    app.add_routes([route])