controlnet
- app.py +2 -2
- app_init.py +26 -42
- frontend/package-lock.json +8 -0
- frontend/package.json +4 -1
- frontend/src/lib/components/Button.svelte +2 -1
- frontend/src/lib/components/Checkbox.svelte +10 -0
- frontend/src/lib/components/ImagePlayer.svelte +10 -4
- frontend/src/lib/components/InputRange.svelte +27 -6
- frontend/src/lib/components/PipelineOptions.svelte +12 -7
- frontend/src/lib/components/SeedInput.svelte +1 -1
- frontend/src/lib/components/VideoInput.svelte +70 -1
- frontend/src/lib/lcmLive.ts +99 -0
- frontend/src/lib/mediaStream.ts +93 -0
- frontend/src/lib/types.ts +2 -0
- frontend/src/lib/utils.ts +145 -0
- frontend/src/routes/+page.svelte +74 -13
- latent_consistency_controlnet.py +0 -1100
- pipelines/controlnet.py +183 -58
- pipelines/txt2img.py +14 -14
- canny_gpu.py → pipelines/utils/canny_gpu.py +0 -0
- requirements.txt +1 -1
- user_queue.py +19 -8
    	
app.py
CHANGED

```diff
@@ -3,7 +3,7 @@ from fastapi import FastAPI
 from config import args
 from device import device, torch_dtype
 from app_init import init_app
-from user_queue import user_queue_map
+from user_queue import user_data_events
 from util import get_pipeline_class
 
 
@@ -11,4 +11,4 @@ app = FastAPI()
 
 pipeline_class = get_pipeline_class(args.pipeline)
 pipeline = pipeline_class(args, device, torch_dtype)
-init_app(app, user_queue_map, args, pipeline)
+init_app(app, user_data_events, args, pipeline)
```
    	
app_init.py
CHANGED

```diff
@@ -6,15 +6,16 @@ from fastapi.staticfiles import StaticFiles
 import logging
 import traceback
 from config import Args
-from user_queue import UserQueueDict
+from user_queue import UserDataEventMap, UserDataEvent
 import uuid
-import asyncio
+from asyncio import Event, sleep
 import time
 from PIL import Image
 import io
+from types import SimpleNamespace
 
 
-def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
+def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipeline):
     app.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],
@@ -27,19 +28,20 @@ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
     @app.websocket("/ws")
     async def websocket_endpoint(websocket: WebSocket):
         await websocket.accept()
-        if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
+        if args.max_queue_size > 0 and len(user_data_events) >= args.max_queue_size:
             print("Server is full")
             await websocket.send_json({"status": "error", "message": "Server is full"})
             await websocket.close()
             return
 
         try:
-            uid = uuid.uuid4()
+            uid = str(uuid.uuid4())
             print(f"New user connected: {uid}")
             await websocket.send_json(
                 {"status": "success", "message": "Connected", "userId": uid}
             )
-            user_queue_map[uid] = {"queue": asyncio.Queue()}
+            user_data_events[uid] = UserDataEvent()
+            print(f"User data events: {user_data_events}")
             await websocket.send_json(
                 {"status": "start", "message": "Start Streaming", "userId": uid}
             )
@@ -49,40 +51,27 @@ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
             traceback.print_exc()
         finally:
             print(f"User disconnected: {uid}")
-            queue_value = user_queue_map.pop(uid, None)
-            queue = queue_value.get("queue", None)
-            if queue:
-                while not queue.empty():
-                    try:
-                        queue.get_nowait()
-                    except asyncio.QueueEmpty:
-                        continue
+            del user_data_events[uid]
 
     @app.get("/queue_size")
     async def get_queue_size():
-        queue_size = len(user_queue_map)
+        queue_size = len(user_data_events)
         return JSONResponse({"queue_size": queue_size})
 
     @app.get("/stream/{user_id}")
     async def stream(user_id: uuid.UUID):
-        uid = user_id
+        uid = str(user_id)
         try:
-            user_queue = user_queue_map[uid]
-            queue = user_queue["queue"]
 
             async def generate():
                 last_prompt: str = None
                 while True:
-                    data = await queue.get()
-                    input_image = data["image"]
+                    data = await user_data_events[uid].wait_for_data()
                     params = data["params"]
-                    if input_image is None:
-                        continue
-
-                    image = pipeline.predict(
-                        input_image,
-                        params,
-                    )
+                    # input_image = data["image"]
+                    # if input_image is None:
+                    #     continue
+                    image = pipeline.predict(params)
                     if image is None:
                         continue
                     frame_data = io.BytesIO()
@@ -91,36 +80,31 @@ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
                     if frame_data is not None and len(frame_data) > 0:
                         yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
 
-                    await asyncio.sleep(1.0 / 120.0)
+                    await sleep(1.0 / 120.0)
 
             return StreamingResponse(
                 generate(), media_type="multipart/x-mixed-replace;boundary=frame"
             )
         except Exception as e:
-            logging.error(f"Streaming Error: {e}, {user_queue_map}")
+            logging.error(f"Streaming Error: {e}, {user_data_events}")
             traceback.print_exc()
             return HTTPException(status_code=404, detail="User not found")
 
     async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
-        uid = user_id
-        user_queue = user_queue_map[uid]
-        queue = user_queue["queue"]
-        if not queue:
+        uid = str(user_id)
+        if uid not in user_data_events:
             return HTTPException(status_code=404, detail="User not found")
         last_time = time.time()
         try:
             while True:
-                data = await websocket.receive_bytes()
                 params = await websocket.receive_json()
                 params = pipeline.InputParams(**params)
-                pil_image = Image.open(io.BytesIO(data))
-                …
-                        continue
-                await queue.put({"image": pil_image, "params": params})
+                params = SimpleNamespace(**params.dict())
+                if hasattr(params, "image"):
+                    image_data = await websocket.receive_bytes()
+                    pil_image = Image.open(io.BytesIO(image_data))
+                    params.image = pil_image
+                user_data_events[uid].update_data({"params": params})
                 if args.timeout > 0 and time.time() - last_time > args.timeout:
                     await websocket.send_json(
                         {
```
    	
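The `user_queue.py` diff (+19 -8) is not shown on this page, but the calls above pin down its new shape: a per-user, single-slot handoff built on `asyncio.Event` in place of the old `asyncio.Queue`. A minimal sketch under that assumption, using only the names that appear above (`UserDataEvent`, `update_data`, `wait_for_data`, `UserDataEventMap`, `user_data_events`):

```python
# Plausible sketch of user_queue.py after this commit (not shown in the diff):
# a single-slot, event-based handoff that always holds only the latest payload.
import asyncio
from typing import Any, Dict


class UserDataEvent:
    def __init__(self) -> None:
        self.data_event = asyncio.Event()
        self.data_content: Any = None

    def update_data(self, new_data: Any) -> None:
        # Overwrite the slot and wake the waiting consumer.
        self.data_content = new_data
        self.data_event.set()

    async def wait_for_data(self) -> Any:
        # Block until the producer publishes, then reset for the next frame.
        await self.data_event.wait()
        self.data_event.clear()
        return self.data_content


UserDataEventMap = Dict[str, UserDataEvent]
user_data_events: UserDataEventMap = {}
```

Unlike a queue, this drops stale frames by design: if the client sends faster than the pipeline renders, only the newest `{"params": ...}` payload survives, and teardown shrinks to the single `del user_data_events[uid]` seen in the `finally` block.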
frontend/package-lock.json
CHANGED

```diff
@@ -7,6 +7,9 @@
     "": {
       "name": "frontend",
       "version": "0.0.1",
+      "dependencies": {
+        "rvfc-polyfill": "^1.0.7"
+      },
       "devDependencies": {
         "@sveltejs/adapter-auto": "^2.0.0",
         "@sveltejs/adapter-static": "^2.0.3",
@@ -3035,6 +3038,11 @@
         "queue-microtask": "^1.2.2"
       }
     },
+    "node_modules/rvfc-polyfill": {
+      "version": "1.0.7",
+      "resolved": "https://registry.npmjs.org/rvfc-polyfill/-/rvfc-polyfill-1.0.7.tgz",
+      "integrity": "sha512-seBl7J1J3/k0LuzW2T9fG6JIOpni5AbU+/87LA+zTYKgTVhsfShmS8K/yOo1eeEjGJHnAdkVAUUM+PEjN9Mpkw=="
+    },
     "node_modules/sade": {
       "version": "1.8.1",
       "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz",
```
    	
frontend/package.json
CHANGED

```diff
@@ -33,5 +33,8 @@
     "typescript": "^5.0.0",
     "vite": "^4.4.2"
   },
-  "type": "module"
+  "type": "module",
+  "dependencies": {
+    "rvfc-polyfill": "^1.0.7"
+  }
 }
```
    	
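`rvfc-polyfill` backfills `HTMLVideoElement.requestVideoFrameCallback` for browsers that do not ship it natively (e.g. Firefox at the time of this commit); `VideoInput.svelte` below imports it before registering its per-frame callback.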
frontend/src/lib/components/Button.svelte
CHANGED

```diff
@@ -1,8 +1,9 @@
 <script lang="ts">
   export let classList: string = '';
+  export let disabled: boolean = false;
 </script>
 
-<button class="button {classList}" on:click>
+<button class="button {classList}" on:click {disabled}>
   <slot />
 </button>
 
```
    	
frontend/src/lib/components/Checkbox.svelte
ADDED

```svelte
<script lang="ts">
  import type { FieldProps } from '$lib/types';
  export let value = false;
  export let params: FieldProps;
</script>

<div class="grid max-w-md grid-cols-4 items-center justify-items-start gap-3">
  <label class="text-sm font-medium" for={params.id}>{params?.title}</label>
  <input bind:checked={value} type="checkbox" id={params.id} class="cursor-pointer" />
</div>
```
    	
frontend/src/lib/components/ImagePlayer.svelte
CHANGED

```diff
@@ -1,12 +1,18 @@
 <script lang="ts">
+  import { isLCMRunning, lcmLiveState, lcmLiveActions } from '$lib/lcmLive';
+  import { onFrameChangeStore } from '$lib/mediaStream';
+  import { PUBLIC_BASE_URL } from '$env/static/public';
+
+  $: streamId = $lcmLiveState.streamId;
 </script>
 
 <div class="relative overflow-hidden rounded-lg border border-slate-300">
   <!-- svelte-ignore a11y-missing-attribute -->
-  <img
-    class="aspect-square w-full rounded-lg"
-    …
-  />
+  {#if $isLCMRunning}
+    <img class="aspect-square w-full rounded-lg" src={PUBLIC_BASE_URL + '/stream/' + streamId} />
+  {:else}
+    <div class="aspect-square w-full rounded-lg" />
+  {/if}
   <div class="absolute left-0 top-0 aspect-square w-1/4">
     <div class="relative z-10 aspect-square w-full object-cover">
       <slot />
```
    	
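The `<img>` above needs no JavaScript because browsers render `multipart/x-mixed-replace` (MJPEG) streams natively, and that is exactly what the `/stream/{user_id}` endpoint in `app_init.py` emits. For debugging outside a browser, a rough sketch of consuming that stream in Python (`requests` is an assumed third-party dependency; host, port, and user id are placeholders):

```python
# Rough MJPEG reader: scan the byte stream for JPEG start/end markers and
# dump the first few frames to disk. Assumes the FastAPI server from
# app_init.py is reachable at localhost:7860.
import requests


def read_mjpeg(url: str, max_frames: int = 3) -> None:
    resp = requests.get(url, stream=True)
    buf = b""
    frames = 0
    for chunk in resp.iter_content(chunk_size=4096):
        buf += chunk
        while True:
            start = buf.find(b"\xff\xd8")  # JPEG start-of-image marker
            end = buf.find(b"\xff\xd9")    # JPEG end-of-image marker
            if start == -1 or end == -1 or end < start:
                break
            with open(f"frame_{frames}.jpg", "wb") as f:
                f.write(buf[start:end + 2])
            buf = buf[end + 2:]
            frames += 1
            if frames >= max_frames:
                return


read_mjpeg("http://localhost:7860/stream/<user_id>")  # <user_id> from the websocket handshake
```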
frontend/src/lib/components/InputRange.svelte
CHANGED

```diff
@@ -8,14 +8,14 @@
   });
 </script>
 
-<div class="grid …
-  <label class="text-sm font-medium" for=…
+<div class="grid grid-cols-4 items-center gap-3">
+  <label class="text-sm font-medium" for={params.id}>{params?.title}</label>
   <input
-    class="col-span-2"
+    class="col-span-2 h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-300 dark:bg-gray-500"
     bind:value
     type="range"
-    id=…
-    name=…
+    id={params.id}
+    name={params.id}
     min={params?.min}
     max={params?.max}
     step={params?.step ?? 1}
@@ -24,6 +24,27 @@
     type="number"
     step={params?.step ?? 1}
     bind:value
-    class="rounded-md border …
+    class="rounded-md border px-1 py-1 text-center text-xs font-bold dark:text-black"
   />
 </div>
+<!-- 
+<style lang="postcss" scoped>
+  input[type='range']::-webkit-slider-runnable-track {
+    @apply h-2 cursor-pointer rounded-lg dark:bg-gray-50;
+  }
+  input[type='range']::-webkit-slider-thumb {
+    @apply cursor-pointer rounded-lg dark:bg-gray-50;
+  }
+  input[type='range']::-moz-range-track {
+    @apply cursor-pointer rounded-lg dark:bg-gray-50;
+  }
+  input[type='range']::-moz-range-thumb {
+    @apply cursor-pointer rounded-lg dark:bg-gray-50;
+  }
+  input[type='range']::-ms-track {
+    @apply cursor-pointer rounded-lg dark:bg-gray-50;
+  }
+  input[type='range']::-ms-thumb {
+    @apply cursor-pointer rounded-lg dark:bg-gray-50;
+  }
+</style> -->
```
    	
frontend/src/lib/components/PipelineOptions.svelte
CHANGED

```diff
@@ -5,6 +5,7 @@
   import InputRange from './InputRange.svelte';
   import SeedInput from './SeedInput.svelte';
   import TextArea from './TextArea.svelte';
+  import Checkbox from './Checkbox.svelte';
 
   export let pipelineParams: FieldProps[];
   export let pipelineValues = {} as any;
@@ -17,11 +18,13 @@
   {#if featuredOptions}
     {#each featuredOptions as params}
       {#if params.field === FieldType.range}
-        <InputRange {params} bind:value={pipelineValues[params.…
+        <InputRange {params} bind:value={pipelineValues[params.id]}></InputRange>
       {:else if params.field === FieldType.seed}
-        <SeedInput bind:value={pipelineValues[params.…
+        <SeedInput bind:value={pipelineValues[params.id]}></SeedInput>
      {:else if params.field === FieldType.textarea}
-        <TextArea {params} bind:value={pipelineValues[params.…
+        <TextArea {params} bind:value={pipelineValues[params.id]}></TextArea>
+      {:else if params.field === FieldType.checkbox}
+        <Checkbox {params} bind:value={pipelineValues[params.id]}></Checkbox>
       {/if}
     {/each}
   {/if}
@@ -29,15 +32,17 @@
 
 <details open>
   <summary class="cursor-pointer font-medium">Advanced Options</summary>
-  <div class="…
+  <div class="grid grid-cols-1 items-center gap-3 sm:grid-cols-2">
     {#if advanceOptions}
       {#each advanceOptions as params}
         {#if params.field === FieldType.range}
-          <InputRange {params} bind:value={pipelineValues[params.…
+          <InputRange {params} bind:value={pipelineValues[params.id]}></InputRange>
         {:else if params.field === FieldType.seed}
-          <SeedInput bind:value={pipelineValues[params.…
+          <SeedInput bind:value={pipelineValues[params.id]}></SeedInput>
         {:else if params.field === FieldType.textarea}
-          <TextArea {params} bind:value={pipelineValues[params.…
+          <TextArea {params} bind:value={pipelineValues[params.id]}></TextArea>
+        {:else if params.field === FieldType.checkbox}
+          <Checkbox {params} bind:value={pipelineValues[params.id]}></Checkbox>
         {/if}
       {/each}
     {/if}
```
    	
frontend/src/lib/components/SeedInput.svelte
CHANGED

```diff
@@ -16,5 +16,5 @@
     name="seed"
     class="col-span-2 rounded-md border border-gray-700 p-2 text-right font-light dark:text-black"
   />
-  <Button on:click={randomize}>…
+  <Button on:click={randomize}>Rand</Button>
 </div>
```
    	
frontend/src/lib/components/VideoInput.svelte
CHANGED

```diff
@@ -1,4 +1,73 @@
 <script lang="ts">
+  import 'rvfc-polyfill';
+  import { onMount, onDestroy } from 'svelte';
+  import {
+    mediaStreamState,
+    mediaStreamActions,
+    isMediaStreaming,
+    MediaStreamStatus,
+    onFrameChangeStore
+  } from '$lib/mediaStream';
+
+  $: mediaStream = $mediaStreamState.mediaStream;
+
+  let videoEl: HTMLVideoElement;
+  let videoFrameCallbackId: number;
+  const WIDTH = 512;
+  const HEIGHT = 512;
+
+  onDestroy(() => {
+    if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
+  });
+
+  function srcObject(node: HTMLVideoElement, stream: MediaStream) {
+    node.srcObject = stream;
+    return {
+      update(newStream: MediaStream) {
+        if (node.srcObject != newStream) {
+          node.srcObject = newStream;
+        }
+      }
+    };
+  }
+  async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
+    const blob = await grapBlobImg();
+    onFrameChangeStore.set({ now, metadata, blob });
+    videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
+  }
+
+  $: if ($isMediaStreaming == MediaStreamStatus.CONNECTED) {
+    videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
+  }
+  async function grapBlobImg() {
+    const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
+    const videoW = videoEl.videoWidth;
+    const videoH = videoEl.videoHeight;
+    const aspectRatio = WIDTH / HEIGHT;
+
+    const ctx = canvas.getContext('2d') as OffscreenCanvasRenderingContext2D;
+    ctx.drawImage(
+      videoEl,
+      videoW / 2 - (videoH * aspectRatio) / 2,
+      0,
+      videoH * aspectRatio,
+      videoH,
+      0,
+      0,
+      WIDTH,
+      HEIGHT
+    );
+    const blob = await canvas.convertToBlob({ type: 'image/jpeg', quality: 1 });
+    return blob;
+  }
 </script>
 
-<video…
+<video
+  class="aspect-square w-full object-cover"
+  bind:this={videoEl}
+  playsinline
+  autoplay
+  muted
+  loop
+  use:srcObject={mediaStream}
+></video>
```
    	
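`grapBlobImg()` center-crops the camera frame to a square before encoding: with `aspectRatio = WIDTH / HEIGHT = 1`, `drawImage` reads a `videoH × videoH` source rectangle starting at `videoW / 2 - videoH / 2` and scales it onto the 512×512 canvas. The same arithmetic in Python/PIL, purely for illustration (not part of the commit):

```python
# Illustrative PIL equivalent of grapBlobImg(): center-crop a (landscape)
# frame to a square, resize to 512x512, and encode as JPEG bytes.
from io import BytesIO
from PIL import Image

WIDTH, HEIGHT = 512, 512


def grab_blob_img(frame: Image.Image) -> bytes:
    video_w, video_h = frame.size
    aspect_ratio = WIDTH / HEIGHT  # 1.0 for a square target
    left = int(video_w / 2 - (video_h * aspect_ratio) / 2)
    box = (left, 0, left + int(video_h * aspect_ratio), video_h)
    square = frame.crop(box).resize((WIDTH, HEIGHT))
    out = BytesIO()
    square.save(out, format="JPEG", quality=95)
    return out.getvalue()
```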
frontend/src/lib/lcmLive.ts
ADDED

```ts
import { writable } from 'svelte/store';
import { PUBLIC_BASE_URL, PUBLIC_WSS_URL } from '$env/static/public';

export const isStreaming = writable(false);
export const isLCMRunning = writable(false);


export enum LCMLiveStatus {
    INIT = "init",
    CONNECTED = "connected",
    DISCONNECTED = "disconnected",
}

interface lcmLive {
    streamId: string | null;
    status: LCMLiveStatus
}

const initialState: lcmLive = {
    streamId: null,
    status: LCMLiveStatus.INIT
};

export const lcmLiveState = writable(initialState);

let websocket: WebSocket | null = null;
export const lcmLiveActions = {
    async start() {
        isLCMRunning.set(true);
        try {
            const websocketURL = PUBLIC_WSS_URL ? PUBLIC_WSS_URL : `${window.location.protocol === "https:" ? "wss" : "ws"
                }:${window.location.host}/ws`;

            websocket = new WebSocket(websocketURL);
            websocket.onopen = () => {
                console.log("Connected to websocket");
            };
            websocket.onclose = () => {
                lcmLiveState.update((state) => ({
                    ...state,
                    status: LCMLiveStatus.DISCONNECTED
                }));
                console.log("Disconnected from websocket");
                isLCMRunning.set(false);
            };
            websocket.onerror = (err) => {
                console.error(err);
            };
            websocket.onmessage = (event) => {
                const data = JSON.parse(event.data);
                console.log("WS: ", data);
                switch (data.status) {
                    case "success":
                        break;
                    case "start":
                        const streamId = data.userId;
                        lcmLiveState.update((state) => ({
                            ...state,
                            status: LCMLiveStatus.CONNECTED,
                            streamId: streamId,
                        }));
                        break;
                    case "timeout":
                        console.log("timeout");
                    case "error":
                        console.log(data.message);
                        isLCMRunning.set(false);
                }
            };
            lcmLiveState.update((state) => ({
                ...state,
            }));
        } catch (err) {
            console.error(err);
            isLCMRunning.set(false);
        }
    },
    send(data: Blob | { [key: string]: any }) {
        if (websocket && websocket.readyState === WebSocket.OPEN) {
            if (data instanceof Blob) {
                websocket.send(data);
            } else {
                websocket.send(JSON.stringify(data));
            }
        } else {
            console.log("WebSocket not connected");
        }
    },
    async stop() {
        if (websocket) {
            websocket.close();
        }
        websocket = null;
        lcmLiveState.set({ status: LCMLiveStatus.DISCONNECTED, streamId: null });
        isLCMRunning.set(false)
    },
};
```
    	
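Taken together with `app_init.py`, the client protocol is: connect to `/ws`, read a `success` message carrying `userId`, read `start`, then repeatedly send a JSON params object, followed by raw JPEG bytes whenever the params schema includes an `image` field, while the `<img>` tag pulls frames from `/stream/{userId}`. A hypothetical Python smoke-test client under those assumptions (`websockets` is an assumed third-party package, and the param names are illustrative, not taken from this commit):

```python
# Hypothetical end-to-end smoke test for the /ws protocol sketched above.
import asyncio
import json

import websockets  # assumed third-party dependency


async def main() -> None:
    async with websockets.connect("ws://localhost:7860/ws") as ws:
        connected = json.loads(await ws.recv())  # {"status": "success", ..., "userId": ...}
        user_id = connected["userId"]
        await ws.recv()  # {"status": "start", ...}

        params = {"prompt": "portrait of a cat", "seed": 42}  # illustrative fields
        await ws.send(json.dumps(params))
        # If the pipeline's InputParams schema has an "image" field, the server
        # expects the raw JPEG bytes as the next websocket message:
        # await ws.send(open("frame.jpg", "rb").read())

        print(f"frames available at http://localhost:7860/stream/{user_id}")
        await asyncio.sleep(2)  # keep the connection open while viewing


asyncio.run(main())
```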
        frontend/src/lib/mediaStream.ts
    ADDED
    
    | @@ -0,0 +1,93 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import { writable, type Writable } from 'svelte/store';
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            export enum MediaStreamStatus {
         | 
| 4 | 
            +
                INIT = "init",
         | 
| 5 | 
            +
                CONNECTED = "connected",
         | 
| 6 | 
            +
                DISCONNECTED = "disconnected",
         | 
| 7 | 
            +
            }
         | 
| 8 | 
            +
            export const onFrameChangeStore: Writable<{ now: Number, metadata: VideoFrameCallbackMetadata, blob: Blob }> = writable();
         | 
| 9 | 
            +
            export const isMediaStreaming = writable(MediaStreamStatus.INIT);
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            interface mediaStream {
         | 
| 12 | 
            +
                mediaStream: MediaStream | null;
         | 
| 13 | 
            +
                status: MediaStreamStatus
         | 
| 14 | 
            +
                devices: MediaDeviceInfo[];
         | 
| 15 | 
            +
            }
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            const initialState: mediaStream = {
         | 
| 18 | 
            +
                mediaStream: null,
         | 
| 19 | 
            +
                status: MediaStreamStatus.INIT,
         | 
| 20 | 
            +
                devices: [],
         | 
| 21 | 
            +
            };
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            export const mediaStreamState = writable(initialState);
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            export const mediaStreamActions = {
         | 
| 26 | 
            +
                async enumerateDevices() {
         | 
| 27 | 
            +
                    console.log("Enumerating devices");
         | 
| 28 | 
            +
                    await navigator.mediaDevices.enumerateDevices()
         | 
| 29 | 
            +
                        .then(devices => {
         | 
| 30 | 
            +
                            const cameras = devices.filter(device => device.kind === 'videoinput');
         | 
| 31 | 
            +
                            console.log("Cameras: ", cameras);
         | 
| 32 | 
            +
                            mediaStreamState.update((state) => ({
         | 
| 33 | 
            +
                                ...state,
         | 
| 34 | 
            +
                                devices: cameras,
         | 
| 35 | 
            +
                            }));
         | 
| 36 | 
            +
                        })
         | 
| 37 | 
            +
                        .catch(err => {
         | 
| 38 | 
            +
                            console.error(err);
         | 
| 39 | 
            +
                        });
         | 
| 40 | 
            +
                },
         | 
| 41 | 
            +
                async start(mediaDevicedID?: string) {
         | 
| 42 | 
            +
                    const constraints = {
         | 
| 43 | 
            +
                        audio: false,
         | 
| 44 | 
            +
                        video: {
         | 
| 45 | 
            +
                            width: 1024, height: 1024, deviceId: mediaDevicedID
         | 
| 46 | 
            +
                        }
         | 
| 47 | 
            +
                    };
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    await navigator.mediaDevices
         | 
| 50 | 
            +
                        .getUserMedia(constraints)
         | 
| 51 | 
            +
                        .then((mediaStream) => {
         | 
| 52 | 
            +
                            mediaStreamState.update((state) => ({
         | 
| 53 | 
            +
                                ...state,
         | 
| 54 | 
            +
                                mediaStream: mediaStream,
         | 
| 55 | 
            +
                                status: MediaStreamStatus.CONNECTED,
         | 
| 56 | 
            +
                            }));
         | 
| 57 | 
            +
                            isMediaStreaming.set(MediaStreamStatus.CONNECTED);
         | 
| 58 | 
            +
                        })
         | 
| 59 | 
            +
                        .catch((err) => {
         | 
| 60 | 
            +
                            console.error(`${err.name}: ${err.message}`);
         | 
| 61 | 
            +
                            isMediaStreaming.set(MediaStreamStatus.DISCONNECTED);
         | 
| 62 | 
            +
                        });
         | 
| 63 | 
            +
                },
         | 
| 64 | 
            +
    async switchCamera(mediaDeviceID: string) {
         | 
| 65 | 
            +
                    const constraints = {
         | 
| 66 | 
            +
                        audio: false,
         | 
| 67 | 
            +
            video: { width: 1024, height: 1024, deviceId: mediaDeviceID }
         | 
| 68 | 
            +
                    };
         | 
| 69 | 
            +
                    await navigator.mediaDevices
         | 
| 70 | 
            +
                        .getUserMedia(constraints)
         | 
| 71 | 
            +
                        .then((mediaStream) => {
         | 
| 72 | 
            +
                            mediaStreamState.update((state) => ({
         | 
| 73 | 
            +
                                ...state,
         | 
| 74 | 
            +
                                mediaStream: mediaStream,
         | 
| 75 | 
            +
                                status: MediaStreamStatus.CONNECTED,
         | 
| 76 | 
            +
                            }));
         | 
| 77 | 
            +
                        })
         | 
| 78 | 
            +
                        .catch((err) => {
         | 
| 79 | 
            +
                            console.error(`${err.name}: ${err.message}`);
         | 
| 80 | 
            +
                        });
         | 
| 81 | 
            +
                },
         | 
| 82 | 
            +
                async stop() {
         | 
| 83 | 
            +
        // stop the tracks of the active stream rather than requesting a brand-new one
         | 
| 84 | 
            +
        const activeStream = get(mediaStreamState).mediaStream;
         | 
| 85 | 
            +
        activeStream?.getTracks().forEach((track) => track.stop());
         | 
| 86 | 
            +
                    mediaStreamState.update((state) => ({
         | 
| 87 | 
            +
                        ...state,
         | 
| 88 | 
            +
                        mediaStream: null,
         | 
| 89 | 
            +
                        status: MediaStreamStatus.DISCONNECTED,
         | 
| 90 | 
            +
                    }));
         | 
| 91 | 
            +
                    isMediaStreaming.set(MediaStreamStatus.DISCONNECTED);
         | 
| 92 | 
            +
                },
         | 
| 93 | 
            +
            };
         | 
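
The actions above form a small camera state machine: `enumerateDevices` fills `devices`, `start`/`switchCamera` swap the active `MediaStream`, and `stop` tears it down. A minimal consumer sketch (not part of this commit; `attachCamera` and `videoEl` are illustrative names):

```ts
import { mediaStreamState, mediaStreamActions } from '$lib/mediaStream';

// hypothetical helper: open the default camera and pipe it into a <video> element,
// re-attaching whenever the store swaps streams (start() or switchCamera())
export async function attachCamera(videoEl: HTMLVideoElement): Promise<void> {
    await mediaStreamActions.enumerateDevices(); // fills state.devices with videoinput entries
    await mediaStreamActions.start();            // opens the camera; status becomes CONNECTED
    mediaStreamState.subscribe((state) => {
        if (state.mediaStream) {
            videoEl.srcObject = state.mediaStream;
        }
    });
}
```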
    	
        frontend/src/lib/types.ts
    CHANGED
    
    | @@ -2,6 +2,7 @@ export const enum FieldType { | |
| 2 | 
             
                range = "range",
         | 
| 3 | 
             
                seed = "seed",
         | 
| 4 | 
             
                textarea = "textarea",
         | 
|  | |
| 5 | 
             
            }
         | 
| 6 |  | 
| 7 | 
             
            export interface FieldProps {
         | 
| @@ -13,6 +14,7 @@ export interface FieldProps { | |
| 13 | 
             
                step?: number;
         | 
| 14 | 
             
                disabled?: boolean;
         | 
| 15 | 
             
                hide?: boolean;
         | 
|  | |
| 16 | 
             
            }
         | 
| 17 | 
             
            export interface PipelineInfo {
         | 
| 18 | 
             
                name: string;
         | 
|  | |
| 2 | 
             
                range = "range",
         | 
| 3 | 
             
                seed = "seed",
         | 
| 4 | 
             
                textarea = "textarea",
         | 
| 5 | 
            +
                checkbox = "checkbox",
         | 
| 6 | 
             
            }
         | 
| 7 |  | 
| 8 | 
             
            export interface FieldProps {
         | 
|  | |
| 14 | 
             
                step?: number;
         | 
| 15 | 
             
                disabled?: boolean;
         | 
| 16 | 
             
                hide?: boolean;
         | 
| 17 | 
            +
                id: string;
         | 
| 18 | 
             
            }
         | 
| 19 | 
             
            export interface PipelineInfo {
         | 
| 20 | 
             
                name: string;
         | 
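
The two `types.ts` additions work together: every field now carries a required `id` the frontend can bind inputs to, and `checkbox` lets the pipeline schema describe boolean toggles. A sketch of a field object under these types (members of `FieldProps` beyond those visible in the diff are unknown, so the object stays within the known subset):

```ts
import { FieldType, type FieldProps } from '$lib/types';

// illustrative only: uses just the members visible in the diff
const debugCannyField: Partial<FieldProps> & { id: string } = {
    id: 'debug_canny', // `id` is now required on every field
    hide: false,
    disabled: false,
};

console.log(debugCannyField.id, FieldType.checkbox); // "debug_canny" "checkbox"
```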
    	
        frontend/src/lib/utils.ts
    ADDED
    
    | @@ -0,0 +1,145 @@ | |
| 1 | 
            +
export function LCMLive(webcamVideo: HTMLVideoElement, liveImage: HTMLImageElement) {
         | 
| 2 | 
            +
                let websocket: WebSocket;
         | 
| 3 | 
            +
             | 
| 4 | 
            +
                async function start() {
         | 
| 5 | 
            +
                    return new Promise((resolve, reject) => {
         | 
| 6 | 
            +
            const protocol = window.location.protocol === "https:" ? "wss" : "ws";
         | 
| 7 | 
            +
            const websocketURL = `${protocol}://${window.location.host}/ws`;
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                        const socket = new WebSocket(websocketURL);
         | 
| 10 | 
            +
                        socket.onopen = () => {
         | 
| 11 | 
            +
                            console.log("Connected to websocket");
         | 
| 12 | 
            +
                        };
         | 
| 13 | 
            +
                        socket.onclose = () => {
         | 
| 14 | 
            +
                            console.log("Disconnected from websocket");
         | 
| 15 | 
            +
                            stop();
         | 
| 16 | 
            +
                            resolve({ "status": "disconnected" });
         | 
| 17 | 
            +
                        };
         | 
| 18 | 
            +
                        socket.onerror = (err) => {
         | 
| 19 | 
            +
                            console.error(err);
         | 
| 20 | 
            +
                            reject(err);
         | 
| 21 | 
            +
                        };
         | 
| 22 | 
            +
                        socket.onmessage = (event) => {
         | 
| 23 | 
            +
                            const data = JSON.parse(event.data);
         | 
| 24 | 
            +
                            switch (data.status) {
         | 
| 25 | 
            +
                                case "success":
         | 
| 26 | 
            +
                                    break;
         | 
| 27 | 
            +
                                case "start":
         | 
| 28 | 
            +
                        // read the id inline to avoid a lexical declaration inside a case block
         | 
| 29 | 
            +
                        initVideoStream(data.userId);
         | 
| 30 | 
            +
                                    break;
         | 
| 31 | 
            +
                                case "timeout":
         | 
| 32 | 
            +
                                    stop();
         | 
| 33 | 
            +
                        resolve({ "status": "timeout" });
         | 
|  | 
            +
                        break; // without this, "timeout" fell straight through into the "error" case
         | 
| 34 | 
            +
                                case "error":
         | 
| 35 | 
            +
                                    stop();
         | 
| 36 | 
            +
                                    reject(data.message);
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                            }
         | 
| 39 | 
            +
                        };
         | 
| 40 | 
            +
                        websocket = socket;
         | 
| 41 | 
            +
                    })
         | 
| 42 | 
            +
                }
         | 
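|  | 
            +
    // NOTE: webcamsEl and getValue below are not defined in this module; they are
    // assumed page-level helpers carried over from the pre-Svelte implementation
         | 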
| 43 | 
            +
                function switchCamera() {
         | 
| 44 | 
            +
                    const constraints = {
         | 
| 45 | 
            +
                        audio: false,
         | 
| 46 | 
            +
                        video: { width: 1024, height: 1024, deviceId: mediaDevices[webcamsEl.value].deviceId }
         | 
| 47 | 
            +
                    };
         | 
| 48 | 
            +
                    navigator.mediaDevices
         | 
| 49 | 
            +
                        .getUserMedia(constraints)
         | 
| 50 | 
            +
                        .then((mediaStream) => {
         | 
| 51 | 
            +
                            webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
         | 
| 52 | 
            +
                            webcamVideo.srcObject = mediaStream;
         | 
| 53 | 
            +
                            webcamVideo.onloadedmetadata = () => {
         | 
| 54 | 
            +
                                webcamVideo.play();
         | 
| 55 | 
            +
                                webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
         | 
| 56 | 
            +
                            };
         | 
| 57 | 
            +
                        })
         | 
| 58 | 
            +
                        .catch((err) => {
         | 
| 59 | 
            +
                            console.error(`${err.name}: ${err.message}`);
         | 
| 60 | 
            +
                        });
         | 
| 61 | 
            +
                }
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                async function videoTimeUpdateHandler() {
         | 
| 64 | 
            +
                    const dimension = getValue("input[name=dimension]:checked");
         | 
| 65 | 
            +
                    const [WIDTH, HEIGHT] = JSON.parse(dimension);
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
         | 
| 68 | 
            +
                    const videoW = webcamVideo.videoWidth;
         | 
| 69 | 
            +
                    const videoH = webcamVideo.videoHeight;
         | 
| 70 | 
            +
                    const aspectRatio = WIDTH / HEIGHT;
         | 
| 71 | 
            +
             | 
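|  | 
            +
        // center-crop: sample a (videoH * aspectRatio) x videoH window from the middle of the camera frame, then scale it to WIDTH x HEIGHT
         | 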
| 72 | 
            +
                    const ctx = canvas.getContext("2d");
         | 
| 73 | 
            +
        ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT);
         | 
| 74 | 
            +
                    const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
         | 
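|  | 
            +
        // wire protocol: one binary JPEG frame, immediately followed by one JSON settings message
         | 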
| 75 | 
            +
                    websocket.send(blob);
         | 
| 76 | 
            +
                    websocket.send(JSON.stringify({
         | 
| 77 | 
            +
                        "seed": getValue("#seed"),
         | 
| 78 | 
            +
                        "prompt": getValue("#prompt"),
         | 
| 79 | 
            +
                        "guidance_scale": getValue("#guidance-scale"),
         | 
| 80 | 
            +
                        "strength": getValue("#strength"),
         | 
| 81 | 
            +
                        "steps": getValue("#steps"),
         | 
| 82 | 
            +
                        "lcm_steps": getValue("#lcm_steps"),
         | 
| 83 | 
            +
                        "width": WIDTH,
         | 
| 84 | 
            +
                        "height": HEIGHT,
         | 
| 85 | 
            +
                        "controlnet_scale": getValue("#controlnet_scale"),
         | 
| 86 | 
            +
                        "controlnet_start": getValue("#controlnet_start"),
         | 
| 87 | 
            +
                        "controlnet_end": getValue("#controlnet_end"),
         | 
| 88 | 
            +
                        "canny_low_threshold": getValue("#canny_low_threshold"),
         | 
| 89 | 
            +
                        "canny_high_threshold": getValue("#canny_high_threshold"),
         | 
| 90 | 
            +
                        "debug_canny": getValue("#debug_canny")
         | 
| 91 | 
            +
                    }));
         | 
| 92 | 
            +
                }
         | 
| 93 | 
            +
                let mediaDevices = [];
         | 
| 94 | 
            +
                async function initVideoStream(userId) {
         | 
| 95 | 
            +
                    liveImage.src = `/stream/${userId}`;
         | 
| 96 | 
            +
                    await navigator.mediaDevices.enumerateDevices()
         | 
| 97 | 
            +
                        .then(devices => {
         | 
| 98 | 
            +
                            const cameras = devices.filter(device => device.kind === 'videoinput');
         | 
| 99 | 
            +
                            mediaDevices = cameras;
         | 
| 100 | 
            +
                            webcamsEl.innerHTML = "";
         | 
| 101 | 
            +
                            cameras.forEach((camera, index) => {
         | 
| 102 | 
            +
                                const option = document.createElement("option");
         | 
| 103 | 
            +
                    option.value = String(index); // HTMLOptionElement.value is a string
         | 
| 104 | 
            +
                                option.innerText = camera.label;
         | 
| 105 | 
            +
                                webcamsEl.appendChild(option);
         | 
| 106 | 
            +
                                option.selected = index === 0;
         | 
| 107 | 
            +
                            });
         | 
| 108 | 
            +
                            webcamsEl.addEventListener("change", switchCamera);
         | 
| 109 | 
            +
                        })
         | 
| 110 | 
            +
                        .catch(err => {
         | 
| 111 | 
            +
                            console.error(err);
         | 
| 112 | 
            +
                        });
         | 
| 113 | 
            +
                    const constraints = {
         | 
| 114 | 
            +
                        audio: false,
         | 
| 115 | 
            +
        video: { width: 1024, height: 1024, deviceId: mediaDevices[0]?.deviceId }
         | 
| 116 | 
            +
                    };
         | 
| 117 | 
            +
                    navigator.mediaDevices
         | 
| 118 | 
            +
                        .getUserMedia(constraints)
         | 
| 119 | 
            +
                        .then((mediaStream) => {
         | 
| 120 | 
            +
                            webcamVideo.srcObject = mediaStream;
         | 
| 121 | 
            +
                            webcamVideo.onloadedmetadata = () => {
         | 
| 122 | 
            +
                                webcamVideo.play();
         | 
| 123 | 
            +
                                webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
         | 
| 124 | 
            +
                            };
         | 
| 125 | 
            +
                        })
         | 
| 126 | 
            +
                        .catch((err) => {
         | 
| 127 | 
            +
                            console.error(`${err.name}: ${err.message}`);
         | 
| 128 | 
            +
                        });
         | 
| 129 | 
            +
                }
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
                async function stop() {
         | 
| 133 | 
            +
                    websocket.close();
         | 
| 134 | 
            +
                    navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
         | 
| 135 | 
            +
                        mediaStream.getTracks().forEach((track) => track.stop());
         | 
| 136 | 
            +
                    });
         | 
| 137 | 
            +
                    webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
         | 
| 138 | 
            +
                    webcamsEl.removeEventListener("change", switchCamera);
         | 
| 139 | 
            +
                    webcamVideo.srcObject = null;
         | 
| 140 | 
            +
                }
         | 
| 141 | 
            +
                return {
         | 
| 142 | 
            +
                    start,
         | 
| 143 | 
            +
                    stop
         | 
| 144 | 
            +
                }
         | 
| 145 | 
            +
            }
         | 
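
`LCMLive` exposes only `start` and `stop`; `start` resolves once the socket closes or the server reports a timeout. A usage sketch (the element ids are hypothetical):

```ts
import { LCMLive } from '$lib/utils';

const webcamVideo = document.querySelector<HTMLVideoElement>('#webcam')!;
const liveImage = document.querySelector<HTMLImageElement>('#live')!;

const lcm = LCMLive(webcamVideo, liveImage);
lcm.start()
    .then((res) => console.log('session ended:', res)) // { status: "disconnected" | "timeout" }
    .catch((err) => console.error('websocket error:', err));
// later, e.g. from a Stop button handler:
// await lcm.stop();
```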
    	
        frontend/src/routes/+page.svelte
    CHANGED
    
    | @@ -7,6 +7,13 @@ | |
| 7 | 
             
              import Button from '$lib/components/Button.svelte';
         | 
| 8 | 
             
              import PipelineOptions from '$lib/components/PipelineOptions.svelte';
         | 
| 9 | 
             
              import Spinner from '$lib/icons/spinner.svelte';
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 10 |  | 
| 11 | 
             
              let pipelineParams: FieldProps[];
         | 
| 12 | 
             
              let pipelineInfo: PipelineInfo;
         | 
| @@ -21,11 +28,58 @@ | |
| 21 | 
             
                pipelineParams = Object.values(settings.input_params.properties);
         | 
| 22 | 
             
                pipelineInfo = settings.info.properties;
         | 
| 23 | 
             
                pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
         | 
|  | |
| 24 | 
             
                console.log('SETTINGS', pipelineInfo);
         | 
| 25 | 
             
              }
         | 
| 26 |  | 
| 27 | 
            -
              $: {
         | 
| 28 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 29 | 
             
              }
         | 
| 30 | 
             
            </script>
         | 
| 31 |  | 
| @@ -58,19 +112,26 @@ | |
| 58 | 
             
                </p>
         | 
| 59 | 
             
              </article>
         | 
| 60 | 
             
              {#if pipelineParams}
         | 
| 61 | 
            -
    (old lines 61-68: removed header/prompt markup; the original text was truncated in this extraction)
         | 
|  | |
|  | |
| 69 | 
             
                <PipelineOptions {pipelineParams} bind:pipelineValues></PipelineOptions>
         | 
| 70 | 
             
                <div class="flex gap-3">
         | 
| 71 | 
            -
      <Button> (old lines 71-73: removed button markup; remainder truncated in this extraction)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 74 | 
             
                </div>
         | 
| 75 |  | 
| 76 | 
             
                <ImagePlayer>
         | 
|  | |
| 7 | 
             
              import Button from '$lib/components/Button.svelte';
         | 
| 8 | 
             
              import PipelineOptions from '$lib/components/PipelineOptions.svelte';
         | 
| 9 | 
             
              import Spinner from '$lib/icons/spinner.svelte';
         | 
| 10 | 
            +
              import { isLCMRunning, lcmLiveState, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
         | 
| 11 | 
            +
              import {
         | 
| 12 | 
            +
                mediaStreamState,
         | 
| 13 | 
            +
                mediaStreamActions,
         | 
| 14 | 
            +
                isMediaStreaming,
         | 
| 15 | 
            +
                onFrameChangeStore
         | 
| 16 | 
            +
              } from '$lib/mediaStream';
         | 
| 17 |  | 
| 18 | 
             
              let pipelineParams: FieldProps[];
         | 
| 19 | 
             
              let pipelineInfo: PipelineInfo;
         | 
|  | |
| 28 | 
             
                pipelineParams = Object.values(settings.input_params.properties);
         | 
| 29 | 
             
                pipelineInfo = settings.info.properties;
         | 
| 30 | 
             
                pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
         | 
| 31 | 
            +
                console.log('PARAMS', pipelineParams);
         | 
| 32 | 
             
                console.log('SETTINGS', pipelineInfo);
         | 
| 33 | 
             
              }
         | 
| 34 |  | 
| 35 | 
            +
              // $: {
         | 
| 36 | 
            +
              //   console.log('isLCMRunning', $isLCMRunning);
         | 
| 37 | 
            +
              // }
         | 
| 38 | 
            +
              // $: {
         | 
| 39 | 
            +
              //   console.log('lcmLiveState', $lcmLiveState);
         | 
| 40 | 
            +
              // }
         | 
| 41 | 
            +
              // $: {
         | 
| 42 | 
            +
              //   console.log('mediaStreamState', $mediaStreamState);
         | 
| 43 | 
            +
              // }
         | 
| 44 | 
            +
              // $: if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
         | 
| 45 | 
            +
              //   lcmLiveActions.send(pipelineValues);
         | 
| 46 | 
            +
              // }
         | 
| 47 | 
            +
              onFrameChangeStore.subscribe(async (frame) => {
         | 
| 48 | 
            +
                if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
         | 
| 49 | 
            +
                  lcmLiveActions.send(pipelineValues);
         | 
| 50 | 
            +
                  lcmLiveActions.send(frame.blob);
         | 
| 51 | 
            +
                }
         | 
| 52 | 
            +
              });
         | 
| 53 | 
            +
              let startBt: Button;
         | 
| 54 | 
            +
              let stopBt: Button;
         | 
| 55 | 
            +
              let snapShotBt: Button;
         | 
| 56 | 
            +
             | 
| 57 | 
            +
              async function toggleLcmLive() {
         | 
| 58 | 
            +
                if (!$isLCMRunning) {
         | 
| 59 | 
            +
                  await mediaStreamActions.enumerateDevices();
         | 
| 60 | 
            +
                  await mediaStreamActions.start();
         | 
| 61 | 
            +
                  lcmLiveActions.start();
         | 
| 62 | 
            +
                } else {
         | 
| 63 | 
            +
                  mediaStreamActions.stop();
         | 
| 64 | 
            +
                  lcmLiveActions.stop();
         | 
| 65 | 
            +
                }
         | 
| 66 | 
            +
              }
         | 
| 67 | 
            +
              async function startLcmLive() {
         | 
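|  | 
            +
  // leftover scaffolding: startLcmLive/stopLcmLive below are not referenced anywhere
  // in this component; toggleLcmLive above is the active code path
         | 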
| 68 | 
            +
                try {
         | 
| 69 | 
            +
                  $isLCMRunning = true;
         | 
| 70 | 
            +
                  // const res = await lcmLive.start();
         | 
| 71 | 
            +
                  $isLCMRunning = false;
         | 
| 72 | 
            +
                  // if (res.status === "timeout")
         | 
| 73 | 
            +
                  // toggleMessage("success")
         | 
| 74 | 
            +
                } catch (err) {
         | 
| 75 | 
            +
                  console.log(err);
         | 
| 76 | 
            +
                  // toggleMessage("error")
         | 
| 77 | 
            +
                  $isLCMRunning = false;
         | 
| 78 | 
            +
                }
         | 
| 79 | 
            +
              }
         | 
| 80 | 
            +
              async function stopLcmLive() {
         | 
| 81 | 
            +
                // await lcmLive.stop();
         | 
| 82 | 
            +
                $isLCMRunning = false;
         | 
| 83 | 
             
              }
         | 
| 84 | 
             
            </script>
         | 
| 85 |  | 
|  | |
| 112 | 
             
                </p>
         | 
| 113 | 
             
              </article>
         | 
| 114 | 
             
              {#if pipelineParams}
         | 
| 115 | 
            +
                <header>
         | 
| 116 | 
            +
                  <h2 class="font-medium">Prompt</h2>
         | 
| 117 | 
            +
                  <p class="text-sm text-gray-500">
         | 
| 118 | 
            +
                    Change the prompt to generate different images, accepts <a
         | 
| 119 | 
            +
                      href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
         | 
| 120 | 
            +
                      target="_blank"
         | 
| 121 | 
            +
                      class="text-blue-500 underline hover:no-underline">Compel</a
         | 
| 122 | 
            +
                    > syntax.
         | 
| 123 | 
            +
                  </p>
         | 
| 124 | 
            +
                </header>
         | 
| 125 | 
             
                <PipelineOptions {pipelineParams} bind:pipelineValues></PipelineOptions>
         | 
| 126 | 
             
                <div class="flex gap-3">
         | 
| 127 | 
            +
                  <Button on:click={toggleLcmLive}>
         | 
| 128 | 
            +
                    {#if $isLCMRunning}
         | 
| 129 | 
            +
                      Stop
         | 
| 130 | 
            +
                    {:else}
         | 
| 131 | 
            +
                      Start
         | 
| 132 | 
            +
                    {/if}
         | 
| 133 | 
            +
                  </Button>
         | 
| 134 | 
            +
                  <Button disabled={$isLCMRunning} classList={'ml-auto'}>Snapshot</Button>
         | 
| 135 | 
             
                </div>
         | 
| 136 |  | 
| 137 | 
             
                <ImagePlayer>
         | 
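
The `onFrameChangeStore` subscription in this page sends two messages per camera frame: the current `pipelineValues` as JSON, then the JPEG blob. `lcmLive.ts` itself is not reproduced in this excerpt, so the following is only an assumed sketch of the send path it implies, not the file's actual contents:

```ts
// assumption: lcmLiveActions.send serializes plain objects and passes Blobs through as binary
function send(websocket: WebSocket, data: Blob | Record<string, unknown>): void {
    if (websocket.readyState !== WebSocket.OPEN) {
        return; // drop messages while the socket is still connecting or already closed
    }
    websocket.send(data instanceof Blob ? data : JSON.stringify(data));
}
```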
    	
        latent_consistency_controlnet.py
    DELETED
    
    | @@ -1,1100 +0,0 @@ | |
| 1 | 
            -
            # from https://github.com/taabata/LCM_Inpaint_Outpaint_Comfy/blob/main/LCM/pipeline_cn.py
         | 
| 2 | 
            -
            # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
         | 
| 17 | 
            -
            # and https://github.com/hojonathanho/diffusion
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            import math
         | 
| 20 | 
            -
            from dataclasses import dataclass
         | 
| 21 | 
            -
            from typing import Any, Dict, List, Optional, Tuple, Union
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            import numpy as np
         | 
| 24 | 
            -
            import torch
         | 
| 25 | 
            -
            from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
         | 
| 26 | 
            -
             | 
| 27 | 
            -
            from diffusers import (
         | 
| 28 | 
            -
                AutoencoderKL,
         | 
| 29 | 
            -
                ConfigMixin,
         | 
| 30 | 
            -
                DiffusionPipeline,
         | 
| 31 | 
            -
                SchedulerMixin,
         | 
| 32 | 
            -
                UNet2DConditionModel,
         | 
| 33 | 
            -
                ControlNetModel,
         | 
| 34 | 
            -
                logging,
         | 
| 35 | 
            -
            )
         | 
| 36 | 
            -
            from diffusers.configuration_utils import register_to_config
         | 
| 37 | 
            -
            from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
         | 
| 38 | 
            -
            from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
         | 
| 39 | 
            -
            from diffusers.pipelines.stable_diffusion.safety_checker import (
         | 
| 40 | 
            -
                StableDiffusionSafetyChecker,
         | 
| 41 | 
            -
            )
         | 
| 42 | 
            -
            from diffusers.utils import BaseOutput
         | 
| 43 | 
            -
             | 
| 44 | 
            -
            from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
         | 
| 45 | 
            -
            from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
         | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
            import PIL.Image
         | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
            # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
         | 
| 55 | 
            -
            def retrieve_latents(encoder_output, generator):
         | 
| 56 | 
            -
                if hasattr(encoder_output, "latent_dist"):
         | 
| 57 | 
            -
                    return encoder_output.latent_dist.sample(generator)
         | 
| 58 | 
            -
                elif hasattr(encoder_output, "latents"):
         | 
| 59 | 
            -
                    return encoder_output.latents
         | 
| 60 | 
            -
                else:
         | 
| 61 | 
            -
                    raise AttributeError("Could not access latents of provided encoder_output")
         | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
            class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
         | 
| 65 | 
            -
                _optional_components = ["scheduler"]
         | 
| 66 | 
            -
             | 
| 67 | 
            -
                def __init__(
         | 
| 68 | 
            -
                    self,
         | 
| 69 | 
            -
                    vae: AutoencoderKL,
         | 
| 70 | 
            -
                    text_encoder: CLIPTextModel,
         | 
| 71 | 
            -
                    tokenizer: CLIPTokenizer,
         | 
| 72 | 
            -
                    controlnet: Union[
         | 
| 73 | 
            -
                        ControlNetModel,
         | 
| 74 | 
            -
                        List[ControlNetModel],
         | 
| 75 | 
            -
                        Tuple[ControlNetModel],
         | 
| 76 | 
            -
                        MultiControlNetModel,
         | 
| 77 | 
            -
                    ],
         | 
| 78 | 
            -
                    unet: UNet2DConditionModel,
         | 
| 79 | 
            -
                    scheduler: "LCMScheduler",
         | 
| 80 | 
            -
                    safety_checker: StableDiffusionSafetyChecker,
         | 
| 81 | 
            -
                    feature_extractor: CLIPImageProcessor,
         | 
| 82 | 
            -
                    requires_safety_checker: bool = True,
         | 
| 83 | 
            -
                ):
         | 
| 84 | 
            -
                    super().__init__()
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                    scheduler = (
         | 
| 87 | 
            -
                        scheduler
         | 
| 88 | 
            -
                        if scheduler is not None
         | 
| 89 | 
            -
                        else LCMScheduler_X(
         | 
| 90 | 
            -
                            beta_start=0.00085,
         | 
| 91 | 
            -
                            beta_end=0.0120,
         | 
| 92 | 
            -
                            beta_schedule="scaled_linear",
         | 
| 93 | 
            -
                            prediction_type="epsilon",
         | 
| 94 | 
            -
                        )
         | 
| 95 | 
            -
                    )
         | 
| 96 | 
            -
             | 
| 97 | 
            -
                    self.register_modules(
         | 
| 98 | 
            -
                        vae=vae,
         | 
| 99 | 
            -
                        text_encoder=text_encoder,
         | 
| 100 | 
            -
                        tokenizer=tokenizer,
         | 
| 101 | 
            -
                        unet=unet,
         | 
| 102 | 
            -
                        controlnet=controlnet,
         | 
| 103 | 
            -
                        scheduler=scheduler,
         | 
| 104 | 
            -
                        safety_checker=safety_checker,
         | 
| 105 | 
            -
                        feature_extractor=feature_extractor,
         | 
| 106 | 
            -
                    )
         | 
| 107 | 
            -
                    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         | 
| 108 | 
            -
                    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         | 
| 109 | 
            -
                    self.control_image_processor = VaeImageProcessor(
         | 
| 110 | 
            -
                        vae_scale_factor=self.vae_scale_factor,
         | 
| 111 | 
            -
                        do_convert_rgb=True,
         | 
| 112 | 
            -
                        do_normalize=False,
         | 
| 113 | 
            -
                    )
         | 
| 114 | 
            -
             | 
| 115 | 
            -
                def _encode_prompt(
         | 
| 116 | 
            -
                    self,
         | 
| 117 | 
            -
                    prompt,
         | 
| 118 | 
            -
                    device,
         | 
| 119 | 
            -
                    num_images_per_prompt,
         | 
| 120 | 
            -
                    prompt_embeds: None,
         | 
| 121 | 
            -
                ):
         | 
| 122 | 
            -
                    r"""
         | 
| 123 | 
            -
                    Encodes the prompt into text encoder hidden states.
         | 
| 124 | 
            -
                    Args:
         | 
| 125 | 
            -
                        prompt (`str` or `List[str]`, *optional*):
         | 
| 126 | 
            -
                            prompt to be encoded
         | 
| 127 | 
            -
                        device: (`torch.device`):
         | 
| 128 | 
            -
                            torch device
         | 
| 129 | 
            -
                        num_images_per_prompt (`int`):
         | 
| 130 | 
            -
                            number of images that should be generated per prompt
         | 
| 131 | 
            -
                        prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 132 | 
            -
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         | 
| 133 | 
            -
                            provided, text embeddings will be generated from `prompt` input argument.
         | 
| 134 | 
            -
                    """
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                    if prompt is not None and isinstance(prompt, str):
         | 
| 137 | 
            -
                        pass
         | 
| 138 | 
            -
                    elif prompt is not None and isinstance(prompt, list):
         | 
| 139 | 
            -
                        len(prompt)
         | 
| 140 | 
            -
                    else:
         | 
| 141 | 
            -
                        prompt_embeds.shape[0]
         | 
| 142 | 
            -
             | 
| 143 | 
            -
                    if prompt_embeds is None:
         | 
| 144 | 
            -
                        text_inputs = self.tokenizer(
         | 
| 145 | 
            -
                            prompt,
         | 
| 146 | 
            -
                            padding="max_length",
         | 
| 147 | 
            -
                            max_length=self.tokenizer.model_max_length,
         | 
| 148 | 
            -
                            truncation=True,
         | 
| 149 | 
            -
                            return_tensors="pt",
         | 
| 150 | 
            -
                        )
         | 
| 151 | 
            -
                        text_input_ids = text_inputs.input_ids
         | 
| 152 | 
            -
                        untruncated_ids = self.tokenizer(
         | 
| 153 | 
            -
                            prompt, padding="longest", return_tensors="pt"
         | 
| 154 | 
            -
                        ).input_ids
         | 
| 155 | 
            -
             | 
| 156 | 
            -
                        if untruncated_ids.shape[-1] >= text_input_ids.shape[
         | 
| 157 | 
            -
                            -1
         | 
| 158 | 
            -
                        ] and not torch.equal(text_input_ids, untruncated_ids):
         | 
| 159 | 
            -
                            removed_text = self.tokenizer.batch_decode(
         | 
| 160 | 
            -
                                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
         | 
| 161 | 
            -
                            )
         | 
| 162 | 
            -
                            logger.warning(
         | 
| 163 | 
            -
                                "The following part of your input was truncated because CLIP can only handle sequences up to"
         | 
| 164 | 
            -
                                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
         | 
| 165 | 
            -
                            )
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                        if (
         | 
| 168 | 
            -
                            hasattr(self.text_encoder.config, "use_attention_mask")
         | 
| 169 | 
            -
                            and self.text_encoder.config.use_attention_mask
         | 
| 170 | 
            -
                        ):
         | 
| 171 | 
            -
                            attention_mask = text_inputs.attention_mask.to(device)
         | 
| 172 | 
            -
                        else:
         | 
| 173 | 
            -
                            attention_mask = None
         | 
| 174 | 
            -
             | 
| 175 | 
            -
                        prompt_embeds = self.text_encoder(
         | 
| 176 | 
            -
                            text_input_ids.to(device),
         | 
| 177 | 
            -
                            attention_mask=attention_mask,
         | 
| 178 | 
            -
                        )
         | 
| 179 | 
            -
                        prompt_embeds = prompt_embeds[0]
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                    if self.text_encoder is not None:
         | 
| 182 | 
            -
                        prompt_embeds_dtype = self.text_encoder.dtype
         | 
| 183 | 
            -
                    elif self.unet is not None:
         | 
| 184 | 
            -
                        prompt_embeds_dtype = self.unet.dtype
         | 
| 185 | 
            -
                    else:
         | 
| 186 | 
            -
                        prompt_embeds_dtype = prompt_embeds.dtype
         | 
| 187 | 
            -
             | 
| 188 | 
            -
                    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
         | 
| 189 | 
            -
             | 
| 190 | 
            -
                    bs_embed, seq_len, _ = prompt_embeds.shape
         | 
| 191 | 
            -
                    # duplicate text embeddings for each generation per prompt, using mps friendly method
         | 
| 192 | 
            -
                    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 193 | 
            -
                    prompt_embeds = prompt_embeds.view(
         | 
| 194 | 
            -
                        bs_embed * num_images_per_prompt, seq_len, -1
         | 
| 195 | 
            -
                    )
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                    # Don't need to get uncond prompt embedding because of LCM Guided Distillation
         | 
| 198 | 
            -
                    return prompt_embeds
         | 
| 199 | 
            -
             | 
| 200 | 
            -
                def run_safety_checker(self, image, device, dtype):
         | 
| 201 | 
            -
                    if self.safety_checker is None:
         | 
| 202 | 
            -
                        has_nsfw_concept = None
         | 
| 203 | 
            -
                    else:
         | 
| 204 | 
            -
                        if torch.is_tensor(image):
         | 
| 205 | 
            -
                            feature_extractor_input = self.image_processor.postprocess(
         | 
| 206 | 
            -
                                image, output_type="pil"
         | 
| 207 | 
            -
                            )
         | 
| 208 | 
            -
                        else:
         | 
| 209 | 
            -
                            feature_extractor_input = self.image_processor.numpy_to_pil(image)
         | 
| 210 | 
            -
                        safety_checker_input = self.feature_extractor(
         | 
| 211 | 
            -
                            feature_extractor_input, return_tensors="pt"
         | 
| 212 | 
            -
                        ).to(device)
         | 
| 213 | 
            -
                        image, has_nsfw_concept = self.safety_checker(
         | 
| 214 | 
            -
                            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
         | 
| 215 | 
            -
                        )
         | 
| 216 | 
            -
                    return image, has_nsfw_concept
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                def prepare_control_image(
         | 
| 219 | 
            -
                    self,
         | 
| 220 | 
            -
                    image,
         | 
| 221 | 
            -
                    width,
         | 
| 222 | 
            -
                    height,
         | 
| 223 | 
            -
                    batch_size,
         | 
| 224 | 
            -
                    num_images_per_prompt,
         | 
| 225 | 
            -
                    device,
         | 
| 226 | 
            -
                    dtype,
         | 
| 227 | 
            -
                    do_classifier_free_guidance=False,
         | 
| 228 | 
            -
                    guess_mode=False,
         | 
| 229 | 
            -
                ):
         | 
| 230 | 
            -
                    image = self.control_image_processor.preprocess(
         | 
| 231 | 
            -
                        image, height=height, width=width
         | 
| 232 | 
            -
                    ).to(dtype=dtype)
         | 
| 233 | 
            -
                    image_batch_size = image.shape[0]
         | 
| 234 | 
            -
             | 
| 235 | 
            -
                    if image_batch_size == 1:
         | 
| 236 | 
            -
                        repeat_by = batch_size
         | 
| 237 | 
            -
                    else:
         | 
| 238 | 
            -
                        # image batch size is the same as prompt batch size
         | 
| 239 | 
            -
                        repeat_by = num_images_per_prompt
         | 
| 240 | 
            -
             | 
| 241 | 
            -
                    image = image.repeat_interleave(repeat_by, dim=0)
         | 
| 242 | 
            -
             | 
| 243 | 
            -
                    image = image.to(device=device, dtype=dtype)
         | 
| 244 | 
            -
             | 
| 245 | 
            -
                    if do_classifier_free_guidance and not guess_mode:
         | 
| 246 | 
            -
                        image = torch.cat([image] * 2)
         | 
| 247 | 
            -
             | 
| 248 | 
            -
                    return image
         | 
| 249 | 
            -
             | 
| 250 | 
            -
                def prepare_latents(
         | 
| 251 | 
            -
                    self,
         | 
| 252 | 
            -
                    image,
         | 
| 253 | 
            -
                    timestep,
         | 
| 254 | 
            -
                    batch_size,
         | 
| 255 | 
            -
                    num_channels_latents,
         | 
| 256 | 
            -
                    height,
         | 
| 257 | 
            -
                    width,
         | 
| 258 | 
            -
                    dtype,
         | 
| 259 | 
            -
                    device,
         | 
| 260 | 
            -
                    latents=None,
         | 
| 261 | 
            -
                    generator=None,
         | 
| 262 | 
            -
                ):
         | 
| 263 | 
            -
                    shape = (
         | 
| 264 | 
            -
                        batch_size,
         | 
| 265 | 
            -
                        num_channels_latents,
         | 
| 266 | 
            -
                        height // self.vae_scale_factor,
         | 
| 267 | 
            -
                        width // self.vae_scale_factor,
         | 
| 268 | 
            -
                    )
         | 
| 269 | 
            -
             | 
| 270 | 
            -
                    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
         | 
| 271 | 
            -
                        raise ValueError(
         | 
| 272 | 
            -
                            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
         | 
| 273 | 
            -
                        )
         | 
| 274 | 
            -
             | 
| 275 | 
            -
                    image = image.to(device=device, dtype=dtype)
         | 
| 276 | 
            -
             | 
| 277 | 
            -
                    # batch_size = batch_size * num_images_per_prompt
         | 
| 278 | 
            -
             | 
| 279 | 
            -
                    if image.shape[1] == 4:
         | 
| 280 | 
            -
                        init_latents = image
         | 
| 281 | 
            -
             | 
| 282 | 
            -
                    else:
         | 
| 283 | 
            -
                        if isinstance(generator, list) and len(generator) != batch_size:
         | 
| 284 | 
            -
                            raise ValueError(
         | 
| 285 | 
            -
                                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
         | 
| 286 | 
            -
                                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
         | 
| 287 | 
            -
                            )
         | 
| 288 | 
            -
             | 
| 289 | 
            -
                        elif isinstance(generator, list):
         | 
| 290 | 
            -
                            init_latents = [
         | 
| 291 | 
            -
                                retrieve_latents(
         | 
| 292 | 
            -
                                    self.vae.encode(image[i : i + 1]), generator=generator[i]
         | 
| 293 | 
            -
                                )
         | 
| 294 | 
            -
                                for i in range(batch_size)
         | 
| 295 | 
            -
                            ]
         | 
| 296 | 
            -
                            init_latents = torch.cat(init_latents, dim=0)
         | 
| 297 | 
            -
                        else:
         | 
| 298 | 
            -
                            init_latents = retrieve_latents(
         | 
| 299 | 
            -
                                self.vae.encode(image), generator=generator
         | 
| 300 | 
            -
                            )
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                        init_latents = self.vae.config.scaling_factor * init_latents
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                    if (
         | 
| 305 | 
            -
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] == 0
        ):
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat(
                [init_latents] * additional_image_per_prompt, dim=0
            )
        elif (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] != 0
        ):
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents: add noise to the encoded image at the requested timestep
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

        # NOTE: everything below this `return` is unreachable; it is leftover
        # from the txt2img variant of `prepare_latents` and is kept only
        # because it was present in the original file.
        if latents is None:
            latents = torch.randn(shape, dtype=dtype).to(device)
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

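    # --- Editor's sketch (not part of the original file) ---
    # For DDPM-family schedulers, `scheduler.add_noise(x0, noise, t)` above
    # implements the closed-form forward process
    #     x_t = sqrt(alphas_cumprod[t]) * x0 + sqrt(1 - alphas_cumprod[t]) * noise,
    # so img2img starts the sampling loop from a partially noised encoding of
    # the input image rather than from pure noise.
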
    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
        Args:
            w: torch.Tensor: 1-D tensor of guidance scale values to embed
            embedding_dim: int: dimension of the embeddings to generate
            dtype: data type of the generated embeddings
        Returns:
            embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

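    # --- Editor's sketch (not part of the original file): `get_w_embedding` ---
    # The guidance scale is embedded with the same sinusoidal scheme used for
    # timesteps and fed to the UNet as `timestep_cond` in `__call__` below.
    # Hypothetical usage, assuming `pipe` is an instance of this pipeline:
    #
    #     >>> w = torch.full((4,), 8.0)  # one guidance value per batch element
    #     >>> pipe.get_w_embedding(w, embedding_dim=256).shape
    #     torch.Size([4, 256])
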
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

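    # --- Editor's sketch (not part of the original file): `get_timesteps` ---
    # With `num_inference_steps=4` and `strength=0.5`: `init_timestep` is 2 and
    # `t_start` is 2, so only the last two scheduler timesteps would be run;
    # lower strength keeps the output closer to the input image. (This helper
    # appears unused here: its call sites below are commented out, and
    # `set_timesteps` receives `strength` directly instead.)
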
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        control_image: PipelineImageInput = None,
        strength: float = 0.8,
        height: Optional[int] = 768,
        width: Optional[int] = 768,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        latents: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = 4,
        lcm_origin_steps: int = 50,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
        guess_mode: bool = True,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        controlnet = (
            self.controlnet._orig_mod
            if is_compiled_module(self.controlnet)
            else self.controlnet
        )
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # 1. Align the formats of `control_guidance_start`/`control_guidance_end` (scalar vs. list)
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = mult * [
                control_guidance_start
            ], mult * [control_guidance_end]
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # do_classifier_free_guidance = guidance_scale > 0.0  # In LCM implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 uses CFG)
        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions
        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            prompt_embeds=prompt_embeds,
        )

        # 3.5 encode image
        image = self.image_processor.preprocess(image)

        if isinstance(controlnet, ControlNetModel):
            control_image = self.prepare_control_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                guess_mode=guess_mode,
            )
        elif isinstance(controlnet, MultiControlNetModel):
            control_images = []

            for control_image_ in control_image:
                control_image_ = self.prepare_control_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    # NOTE: the original code passed
                    # `do_classifier_free_guidance=do_classifier_free_guidance`
                    # here, but that name is never defined in this scope (it is
                    # commented out above), which would raise a NameError.
                    guess_mode=guess_mode,
                )

                control_images.append(control_image_)

            control_image = control_images
        else:
            assert False

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
        # timesteps = self.scheduler.timesteps
        # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
        timesteps = self.scheduler.timesteps
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        print("timesteps: ", timesteps)

         | 
| 502 | 
            -
                    num_channels_latents = self.unet.config.in_channels
         | 
| 503 | 
            -
                    latents = self.prepare_latents(
         | 
| 504 | 
            -
                        image,
         | 
| 505 | 
            -
                        latent_timestep,
         | 
| 506 | 
            -
                        batch_size * num_images_per_prompt,
         | 
| 507 | 
            -
                        num_channels_latents,
         | 
| 508 | 
            -
                        height,
         | 
| 509 | 
            -
                        width,
         | 
| 510 | 
            -
                        prompt_embeds.dtype,
         | 
| 511 | 
            -
                        device,
         | 
| 512 | 
            -
                        latents,
         | 
| 513 | 
            -
                    )
         | 
| 514 | 
            -
                    bs = batch_size * num_images_per_prompt
         | 
| 515 | 
            -
             | 
| 516 | 
            -
        # 6. Get Guidance Scale Embedding
        w = torch.tensor(guidance_scale).repeat(bs)
        w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
            device=device, dtype=latents.dtype
        )
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )
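        # --- Editor's sketch (not part of the original file) ---
        # Example: with 4 timesteps, `control_guidance_start=0.0` and
        # `control_guidance_end=0.5`, the loop above yields
        # `controlnet_keep == [1.0, 1.0, 0.0, 0.0]`, i.e. the ControlNet
        # residuals are applied only during the first half of the schedule.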
        # 7. LCM MultiStep Sampling Loop:
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                ts = torch.full((bs,), t, device=device, dtype=torch.long)
                latents = latents.to(prompt_embeds.dtype)
                if guess_mode:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(
                        control_model_input, ts
                    )
                    controlnet_prompt_embeds = prompt_embeds
                else:
                    control_model_input = latents
                    controlnet_prompt_embeds = prompt_embeds
                if isinstance(controlnet_keep[i], list):
                    cond_scale = [
                        c * s
                        for c, s in zip(
                            controlnet_conditioning_scale, controlnet_keep[i]
                        )
                    ]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    ts,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
                # model prediction (v-prediction, eps, x)
                model_pred = self.unet(
                    latents,
                    ts,
                    timestep_cond=w_embedding,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents, denoised = self.scheduler.step(
                    model_pred, i, t, latents, return_dict=False
                )

                # # call the callback, if provided
                # if i == len(timesteps) - 1:
                progress_bar.update()

        denoised = denoised.to(prompt_embeds.dtype)
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()
        if output_type != "latent":
            image = self.vae.decode(
                denoised / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = denoised
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )

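# --- Editor's sketch (not part of the original file): calling the pipeline ---
# A minimal, hypothetical invocation, assuming `pipe` is an instance of this
# img2img + ControlNet LCM pipeline and the inputs are PIL images:
#
#     >>> out = pipe(
#     ...     prompt="a portrait, oil painting",
#     ...     image=init_image,            # image to transform
#     ...     control_image=canny_image,   # ControlNet conditioning, e.g. Canny edges
#     ...     strength=0.8,
#     ...     num_inference_steps=4,       # LCM needs only a few steps
#     ...     guidance_scale=7.5,
#     ...     controlnet_conditioning_scale=0.8,
#     ... )
#     >>> out.images[0]  # a PIL.Image when output_type="pil"
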
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
@dataclass
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.
    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted fully denoised sample `(x_{0})` based on the model output from the current timestep.
            `denoised` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None

# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].
    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.
    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`.
    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

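# --- Editor's sketch (not part of the original file): `betas_for_alpha_bar` ---
#
#     >>> betas = betas_for_alpha_bar(1000)  # cosine schedule
#     >>> betas.shape
#     torch.Size([1000])
#     >>> bool((betas <= 0.999).all())       # capped by `max_beta`
#     True
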
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.
    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas

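# --- Editor's sketch (not part of the original file): checking the rescale ---
# After rescaling, the cumulative alpha product at the final timestep is zero,
# i.e. the terminal signal-to-noise ratio is exactly 0:
#
#     >>> betas = torch.linspace(0.0001, 0.02, 1000)
#     >>> torch.cumprod(1.0 - rescale_zero_terminal_snr(betas), dim=0)[-1]
#     tensor(0.)
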
class LCMScheduler_X(SchedulerMixin, ConfigMixin):
    """
    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.
    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers, such as loading and saving.
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as in Stable
            Diffusion.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models
            such as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright
            and dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
            )
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(
                    beta_start**0.5,
                    beta_end**0.5,
                    num_train_timesteps,
                    dtype=torch.float32,
                )
                ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(
                f"{beta_schedule} is not implemented for {self.__class__}"
            )

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in DDIM, we are looking into the previous alphas_cumprod.
        # For the final step, there is no previous alphas_cumprod because we are already at 0.
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = (
            torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(
            np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
        )

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Optional[int] = None
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.
        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (
            1 - alpha_prod_t / alpha_prod_t_prev
        )

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."
        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = (
                sample.float()
            )  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = (
            torch.clamp(sample, -s, s) / s
        )  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

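    # --- Editor's sketch (not part of the original file): thresholding demo ---
    # With the default config (`dynamic_thresholding_ratio=0.995`,
    # `sample_max_value=1.0`), `s` is clamped to 1 and the call reduces to
    # standard clipping to [-1, 1]:
    #
    #     >>> sched = LCMScheduler_X()
    #     >>> x = 3.0 * torch.randn(1, 4, 8, 8)
    #     >>> bool(sched._threshold_sample(x).abs().max() <= 1.0)
    #     True
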
    def set_timesteps(
        self,
        strength,
        num_inference_steps: int,
        lcm_origin_steps: int,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
        Args:
            strength (`float`):
                The img2img strength; only the first `int(lcm_origin_steps * strength)` steps of the original LCM
                training schedule are considered.
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            lcm_origin_steps (`int`):
                The number of steps in the original LCM training schedule.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # LCM Timesteps Setting:  # Linear Spacing
        c = self.config.num_train_timesteps // lcm_origin_steps
        lcm_origin_timesteps = (
            np.asarray(list(range(1, int(lcm_origin_steps * strength) + 1))) * c - 1
        )  # LCM Training Steps Schedule
        skipping_step = max(len(lcm_origin_timesteps) // num_inference_steps, 1)
        timesteps = lcm_origin_timesteps[::-skipping_step][
            :num_inference_steps
        ]  # LCM Inference Steps Schedule

        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)

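    # --- Editor's sketch (not part of the original file): resulting schedule ---
    # With `num_train_timesteps=1000`, `lcm_origin_steps=50`, `strength=1.0`
    # and `num_inference_steps=4`: c = 20, the origin schedule is
    # [19, 39, ..., 999], `skipping_step` is 12, and the inference schedule
    # becomes [999, 759, 519, 279].
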
    def get_scalings_for_boundary_condition_discrete(self, t):
        self.sigma_data = 0.5  # Default: 0.5

        # Dividing t by 0.1 makes c_skip almost a delta function at t = 0.
        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out

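    # --- Editor's sketch (not part of the original file): boundary behaviour ---
    # These scalings enforce the consistency-model boundary condition
    # f(x, 0) = x: at t = 0 the raw sample passes through unchanged, while for
    # large t the model prediction dominates. Assuming `sched` is an instance:
    #
    #     >>> sched.get_scalings_for_boundary_condition_discrete(0)
    #     (1.0, 0.0)
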
    def step(
        self,
        model_output: torch.FloatTensor,
        timeindex: int,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[LCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).
        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timeindex (`int`):
                The index of the current timestep in the inference schedule.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in a diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because the predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`.
                If no clipping has happened, the "corrected" `model_output` coincides with the one provided as input
                and `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.FloatTensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or a `tuple`.
        Returns:
            [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, an [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get previous step value
        prev_timeindex = timeindex + 1
        if prev_timeindex < len(self.timesteps):
            prev_timestep = self.timesteps[prev_timeindex]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Different Parameterization:
        parameterization = self.config.prediction_type
         | 
| 1017 | 
            -
             | 
| 1018 | 
            -
                    if parameterization == "epsilon":  # noise-prediction
         | 
| 1019 | 
            -
                        pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
         | 
| 1020 | 
            -
             | 
| 1021 | 
            -
                    elif parameterization == "sample":  # x-prediction
         | 
| 1022 | 
            -
                        pred_x0 = model_output
         | 
| 1023 | 
            -
             | 
| 1024 | 
            -
                    elif parameterization == "v_prediction":  # v-prediction
         | 
| 1025 | 
            -
                        pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
         | 
| 1026 | 
            -
             | 
| 1027 | 
            -
                    # 4. Denoise model output using boundary conditions
         | 
| 1028 | 
            -
                    denoised = c_out * pred_x0 + c_skip * sample
         | 
| 1029 | 
            -
             | 
| 1030 | 
            -
                    # 5. Sample z ~ N(0, I), For MultiStep Inference
         | 
| 1031 | 
            -
                    # Noise is not used for one-step sampling.
         | 
| 1032 | 
            -
                    if len(self.timesteps) > 1:
         | 
| 1033 | 
            -
                        noise = torch.randn(model_output.shape).to(model_output.device)
         | 
| 1034 | 
            -
                        prev_sample = (
         | 
| 1035 | 
            -
                            alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
         | 
| 1036 | 
            -
                        )
         | 
| 1037 | 
            -
                    else:
         | 
| 1038 | 
            -
                        prev_sample = denoised
         | 
| 1039 | 
            -
             | 
| 1040 | 
            -
                    if not return_dict:
         | 
| 1041 | 
            -
                        return (prev_sample, denoised)
         | 
| 1042 | 
            -
             | 
| 1043 | 
            -
                    return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
         | 
| 1044 | 
            -
             | 
| 1045 | 
            -
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
         | 
| 1046 | 
            -
                def add_noise(
         | 
| 1047 | 
            -
                    self,
         | 
| 1048 | 
            -
                    original_samples: torch.FloatTensor,
         | 
| 1049 | 
            -
                    noise: torch.FloatTensor,
         | 
| 1050 | 
            -
                    timesteps: torch.IntTensor,
         | 
| 1051 | 
            -
                ) -> torch.FloatTensor:
         | 
| 1052 | 
            -
                    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         | 
| 1053 | 
            -
                    alphas_cumprod = self.alphas_cumprod.to(
         | 
| 1054 | 
            -
                        device=original_samples.device, dtype=original_samples.dtype
         | 
| 1055 | 
            -
                    )
         | 
| 1056 | 
            -
                    timesteps = timesteps.to(original_samples.device)
         | 
| 1057 | 
            -
             | 
| 1058 | 
            -
                    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 1059 | 
            -
                    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 1060 | 
            -
                    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
         | 
| 1061 | 
            -
                        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 1062 | 
            -
             | 
| 1063 | 
            -
                    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 1064 | 
            -
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 1065 | 
            -
                    while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
         | 
| 1066 | 
            -
                        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 1067 | 
            -
             | 
| 1068 | 
            -
                    noisy_samples = (
         | 
| 1069 | 
            -
                        sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         | 
| 1070 | 
            -
                    )
         | 
| 1071 | 
            -
                    return noisy_samples
         | 
| 1072 | 
            -
             | 
| 1073 | 
            -
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
         | 
| 1074 | 
            -
                def get_velocity(
         | 
| 1075 | 
            -
                    self,
         | 
| 1076 | 
            -
                    sample: torch.FloatTensor,
         | 
| 1077 | 
            -
                    noise: torch.FloatTensor,
         | 
| 1078 | 
            -
                    timesteps: torch.IntTensor,
         | 
| 1079 | 
            -
                ) -> torch.FloatTensor:
         | 
| 1080 | 
            -
                    # Make sure alphas_cumprod and timestep have same device and dtype as sample
         | 
| 1081 | 
            -
                    alphas_cumprod = self.alphas_cumprod.to(
         | 
| 1082 | 
            -
                        device=sample.device, dtype=sample.dtype
         | 
| 1083 | 
            -
                    )
         | 
| 1084 | 
            -
                    timesteps = timesteps.to(sample.device)
         | 
| 1085 | 
            -
             | 
| 1086 | 
            -
                    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 1087 | 
            -
                    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 1088 | 
            -
                    while len(sqrt_alpha_prod.shape) < len(sample.shape):
         | 
| 1089 | 
            -
                        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 1090 | 
            -
             | 
| 1091 | 
            -
                    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 1092 | 
            -
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 1093 | 
            -
                    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
         | 
| 1094 | 
            -
                        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 1095 | 
            -
             | 
| 1096 | 
            -
                    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         | 
| 1097 | 
            -
                    return velocity
         | 
| 1098 | 
            -
             | 
| 1099 | 
            -
                def __len__(self):
         | 
| 1100 | 
            -
                    return self.config.num_train_timesteps
         | 
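The key trick in the scheduler code removed above is the pair of boundary-condition scalings used in step 3/4: denoised = c_out * pred_x0 + c_skip * sample, where the division by 0.1 makes c_skip collapse to a near-delta at t = 0, so the consistency function returns its input unchanged at the boundary. A minimal sketch (not part of the commit) that checks this numerically using the same formulas as the deleted code:

    sigma_data = 0.5  # same default as in the removed scheduler above

    def scalings(t: float):
        # identical formulas to get_scalings_for_boundary_condition_discrete
        c_skip = sigma_data**2 / ((t / 0.1) ** 2 + sigma_data**2)
        c_out = (t / 0.1) / ((t / 0.1) ** 2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    print(scalings(0.0))    # (1.0, 0.0): at t=0, denoised == sample (boundary condition)
    print(scalings(999.0))  # c_skip ~ 0, c_out ~ 1: denoised ~ pred_x0 at large t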
    	
        pipelines/controlnet.py
    CHANGED
    
@@ -1,8 +1,11 @@
-from diffusers import 
+from diffusers import (
+    StableDiffusionControlNetImg2ImgPipeline,
+    AutoencoderTiny,
+    ControlNetModel,
+)
 from compel import Compel
 import torch
+from pipelines.utils.canny_gpu import SobelOperator
 
 try:
     import intel_extension_for_pytorch as ipex  # type: ignore
@@ -11,80 +14,202 @@ except:
 
 import psutil
 from config import Args
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from PIL import Image
-from typing import Callable
 
 base_model = "SimianLuo/LCM_Dreamshaper_v7"
+taesd_model = "madebyollin/taesd"
+controlnet_model = "lllyasviel/control_v11p_sd15_canny"
+
+default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
 
 
 class Pipeline:
+    class Info(BaseModel):
+        name: str = "txt2img"
+        description: str = "Generates an image from a text prompt"
+
     class InputParams(BaseModel):
+        prompt: str = Field(
+            default_prompt,
+            title="Prompt",
+            field="textarea",
+            id="prompt",
+        )
+        seed: int = Field(
+            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
+        )
+        steps: int = Field(
+            4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
+        )
+        width: int = Field(
+            512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
+        )
+        height: int = Field(
+            512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
+        )
+        guidance_scale: float = Field(
+            0.2,
+            min=0,
+            max=2,
+            step=0.001,
+            title="Guidance Scale",
+            field="range",
+            hide=True,
+            id="guidance_scale",
+        )
+        strength: float = Field(
+            0.5,
+            min=0.25,
+            max=1.0,
+            step=0.001,
+            title="Strength",
+            field="range",
+            hide=True,
+            id="strength",
+        )
+        controlnet_scale: float = Field(
+            0.8,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Controlnet Scale",
+            field="range",
+            hide=True,
+            id="controlnet_scale",
+        )
+        controlnet_start: float = Field(
+            0.0,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Controlnet Start",
+            field="range",
+            hide=True,
+            id="controlnet_start",
+        )
+        controlnet_end: float = Field(
+            1.0,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Controlnet End",
+            field="range",
+            hide=True,
+            id="controlnet_end",
+        )
+        canny_low_threshold: float = Field(
+            0.31,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Canny Low Threshold",
+            field="range",
+            hide=True,
+            id="canny_low_threshold",
+        )
+        canny_high_threshold: float = Field(
+            0.125,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Canny High Threshold",
+            field="range",
+            hide=True,
+            id="canny_high_threshold",
+        )
+        debug_canny: bool = Field(
+            False,
+            title="Debug Canny",
+            field="checkbox",
+            hide=True,
+            id="debug_canny",
+        )
+        image: bool = True
+
+    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
+        controlnet_canny = ControlNetModel.from_pretrained(
+            controlnet_model, torch_dtype=torch_dtype
+        ).to(device)
         if args.safety_checker:
-            pipe = 
+            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+                base_model, controlnet=controlnet_canny
+            )
         else:
-            pipe = 
+            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+                base_model,
+                safety_checker=None,
+                controlnet=controlnet_canny,
+            )
         if args.use_taesd:
-            pipe.vae = AutoencoderTiny.from_pretrained(
+            self.pipe.vae = AutoencoderTiny.from_pretrained(
+                taesd_model, torch_dtype=torch_dtype, use_safetensors=True
             )
-
-        pipe.set_progress_bar_config(disable=True)
-        pipe.to(device=device, dtype=torch_dtype)
-        pipe.unet.to(memory_format=torch.channels_last)
+        self.canny_torch = SobelOperator(device=device)
+        self.pipe.set_progress_bar_config(disable=True)
+        self.pipe.to(device=device, dtype=torch_dtype)
+        self.pipe.unet.to(memory_format=torch.channels_last)
 
         # check if computer has less than 64GB of RAM using sys or os
        if psutil.virtual_memory().total < 64 * 1024**3:
-            pipe.enable_attention_slicing()
+            self.pipe.enable_attention_slicing()
 
         if args.torch_compile:
-            pipe.unet = torch.compile(
+            self.pipe.unet = torch.compile(
+                self.pipe.unet, mode="reduce-overhead", fullgraph=True
+            )
+            self.pipe.vae = torch.compile(
+                self.pipe.vae, mode="reduce-overhead", fullgraph=True
+            )
 
-            pipe(
+            self.pipe(
+                prompt="warmup",
+                image=[Image.new("RGB", (768, 768))],
+                control_image=[Image.new("RGB", (768, 768))],
+            )
 
-        compel_proc = Compel(
-            tokenizer=pipe.tokenizer,
-            text_encoder=pipe.text_encoder,
+        self.compel_proc = Compel(
+            tokenizer=self.pipe.tokenizer,
+            text_encoder=self.pipe.text_encoder,
             truncate_long_prompts=False,
         )
 
-        return 
+    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
+        generator = torch.manual_seed(params.seed)
+        prompt_embeds = self.compel_proc(params.prompt)
+        control_image = self.canny_torch(
+            params.image, params.canny_low_threshold, params.canny_high_threshold
+        )
+
+        results = self.pipe(
+            image=params.image,
+            control_image=control_image,
+            prompt_embeds=prompt_embeds,
+            generator=generator,
+            strength=params.strength,
+            num_inference_steps=params.steps,
+            guidance_scale=params.guidance_scale,
+            width=params.width,
+            height=params.height,
+            output_type="pil",
+            controlnet_conditioning_scale=params.controlnet_scale,
+            control_guidance_start=params.controlnet_start,
+            control_guidance_end=params.controlnet_end,
+        )
+
+        nsfw_content_detected = (
+            results.nsfw_content_detected[0]
+            if "nsfw_content_detected" in results
+            else False
+        )
+        if nsfw_content_detected:
+            return None
+        result_image = results.images[0]
+        if params.debug_canny:
+            # paste control_image on top of result_image
+            w0, h0 = (200, 200)
+            control_image = control_image.resize((w0, h0))
+            w1, h1 = result_image.size
+            result_image.paste(control_image, (w1 - w0, h1 - h0))
+
+        return result_image
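For orientation, a minimal usage sketch of the new Pipeline class above (not part of the commit): the args object here stands in for the repo's config.Args and carries only the three flags the constructor reads, and the image field, which is declared as a bool so the frontend knows to send a frame, is assumed to be replaced with the actual PIL frame by the server before predict() is called.

    from types import SimpleNamespace
    import torch
    from PIL import Image
    from pipelines.controlnet import Pipeline

    # Stand-in for config.Args; only the flags read by __init__ are set (assumption).
    args = SimpleNamespace(safety_checker=False, use_taesd=True, torch_compile=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    pipeline = Pipeline(args, device, dtype)

    params = Pipeline.InputParams()  # defaults from the Field() declarations above
    # `image: bool = True` only advertises the input to the frontend; at request
    # time the server swaps in the actual webcam frame (assumed here).
    params.image = Image.open("frame.png").convert("RGB").resize(
        (params.width, params.height)
    )

    result = pipeline.predict(params)  # a PIL image, or None when NSFW is flagged
    if result is not None:
        result.save("out.png")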
    	
        pipelines/txt2img.py
    CHANGED
    
@@ -11,7 +11,6 @@ import psutil
 from config import Args
 from pydantic import BaseModel, Field
 from PIL import Image
-from typing import Callable
 
 base_model = "SimianLuo/LCM_Dreamshaper_v7"
 taesd_model = "madebyollin/taesd"
@@ -29,22 +28,19 @@ class Pipeline:
             default_prompt,
             title="Prompt",
             field="textarea",
+            id="prompt",
         )
-        seed: int = Field(
-            hide=True,
-        )
-        steps: int = Field(4, min=2, max=15, title="Steps", field="range", hide=True)
-        width: int = Field(512, min=2, max=15, title="Width", disabled=True, hide=True)
+        seed: int = Field(
+            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
+        )
+        steps: int = Field(
+            4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
+        )
+        width: int = Field(
+            512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
+        )
         height: int = Field(
-            512, min=2, max=15, title="Height", disabled=True, hide=True
+            512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
         )
         guidance_scale: float = Field(
             8.0,
@@ -54,6 +50,10 @@ class Pipeline:
             title="Guidance Scale",
             field="range",
             hide=True,
+            id="guidance_scale",
+        )
+        image: bool = Field(
+            True, title="Image", field="checkbox", hide=True, id="image"
         )
 
     def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
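The extra keyword arguments threaded through Field() in both pipelines (field, id, hide, min, max, step) are not pydantic validation arguments; with pydantic v1, which this repo appears to use, unrecognized Field kwargs are carried through into the generated JSON schema, which is presumably how the Svelte frontend discovers which widget to render for each parameter. A small sketch of that behavior, assuming pydantic v1:

    from pydantic import BaseModel, Field

    class InputParams(BaseModel):
        steps: int = Field(
            4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
        )

    # Extra kwargs survive into the schema (pydantic v1 behavior; output shown
    # approximately):
    print(InputParams.schema()["properties"]["steps"])
    # {'title': 'Steps', 'default': 4, 'min': 2, 'max': 15,
    #  'field': 'range', 'hide': True, 'id': 'steps', 'type': 'integer'}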
    	
        canny_gpu.py → pipelines/utils/canny_gpu.py
    RENAMED
    
    File without changes
    	
        requirements.txt
    CHANGED
    
@@ -1,4 +1,4 @@
-diffusers
+git+https://github.com/huggingface/diffusers@c697f524761abd2314c030221a3ad2f7791eab4e
 transformers==4.34.1
 gradio==3.50.2
 --extra-index-url https://download.pytorch.org/whl/cu121;
    	
        user_queue.py
    CHANGED
    
@@ -1,18 +1,29 @@
 from typing import Dict, Union
 from uuid import UUID
+import asyncio
 from PIL import Image
-from typing import 
-from uuid import UUID
-from asyncio import Queue
+from typing import Dict, Union
 from PIL import Image
 
+InputParams = dict
 UserId = UUID
+EventDataContent = Dict[str, InputParams]
 
-InputParams = dict
 
-UserQueueDict = Dict[UserId, Queue[QueueContent]]
+class UserDataEvent:
+    def __init__(self):
+        self.data_event = asyncio.Event()
+        self.data_content: EventDataContent = {}
+
+    def update_data(self, new_data: EventDataContent):
+        self.data_content = new_data
+        self.data_event.set()
+
+    async def wait_for_data(self) -> EventDataContent:
+        await self.data_event.wait()
+        self.data_event.clear()
+        return self.data_content
+
+
+UserDataEventMap = Dict[UserId, UserDataEvent]
+user_data_events: UserDataEventMap = {}
 
			
