File size: 1,752 Bytes
8d7f55c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import signal

from pipecat.pipeline.task import PipelineTask
from pipecat.utils.utils import obj_count, obj_id

from loguru import logger


class PipelineRunner:
    """Run ``PipelineTask`` objects and coordinate stopping/cancelling them.

    Every task passed to :meth:`run` is tracked (keyed by ``task.name``) for as
    long as it is running, so that :meth:`stop_when_done` and :meth:`cancel`
    can be forwarded to all of them at once.

    When ``handle_sigint`` is true, the constructor installs SIGINT and SIGTERM
    handlers on the running asyncio event loop that cancel the runner. Because
    this uses :func:`asyncio.get_running_loop`, the runner must then be created
    from inside a running event loop (e.g. within a coroutine).
    """

    def __init__(self, *, name: str | None = None, handle_sigint: bool = True):
        """Create a runner.

        Args:
            name: Human-readable name used in log messages. Defaults to
                ``"<ClassName>#<n>"`` where ``<n>`` is ``obj_count(self)``.
            handle_sigint: If true, install SIGINT/SIGTERM handlers that
                cancel all tasks currently being run by this runner.

        Raises:
            RuntimeError: If ``handle_sigint`` is true and no event loop is
                running in the current thread.
        """
        self.id: int = obj_id()
        self.name: str = name or f"{self.__class__.__name__}#{obj_count(self)}"

        # Tasks currently being run by this runner, keyed by task name.
        self._tasks: dict[str, PipelineTask] = {}
        # Strong references to in-flight signal-handler tasks. The event loop
        # only keeps weak references to tasks, so without this set a handler
        # task could be garbage-collected before it finishes.
        self._sig_tasks: set[asyncio.Task] = set()

        if handle_sigint:
            self._setup_sigint()

    async def run(self, task: PipelineTask):
        """Run ``task`` to completion.

        The task is registered with this runner while it runs, so that
        :meth:`stop_when_done` and :meth:`cancel` reach it. It is unregistered
        afterwards even if ``task.run()`` raises or this coroutine is
        cancelled.

        NOTE(review): tasks are keyed by ``task.name``; two concurrently
        running tasks with the same name shadow each other in the registry.
        """
        logger.debug(f"Runner {self} started running {task}")
        self._tasks[task.name] = task
        try:
            await task.run()
        finally:
            # Only drop our own entry: if a same-named task was registered
            # after us, it must stay reachable by stop_when_done()/cancel().
            if self._tasks.get(task.name) is task:
                del self._tasks[task.name]
        logger.debug(f"Runner {self} finished running {task}")

    async def stop_when_done(self):
        """Ask every currently registered task to stop once it is done.

        Calls ``stop_when_done()`` on all tasks registered at the time of the
        call, concurrently, and waits for all of those calls to return.
        """
        logger.debug(f"Runner {self} scheduled to stop when all tasks are done")
        await asyncio.gather(*[t.stop_when_done() for t in self._tasks.values()])

    async def cancel(self):
        """Cancel every currently registered task.

        Calls ``cancel()`` on all tasks registered at the time of the call,
        concurrently, and waits for all of those calls to return.
        """
        logger.debug(f"Canceling runner {self}")
        await asyncio.gather(*[t.cancel() for t in self._tasks.values()])

    def _setup_sigint(self):
        """Install SIGINT and SIGTERM handlers that cancel this runner.

        ``asyncio.get_running_loop()`` raises ``RuntimeError`` if no loop is
        running; that error is intentionally propagated. If the loop does not
        support signal handlers (``NotImplementedError``, e.g. on Windows), or
        cannot install them (``RuntimeError``, e.g. the loop is not running in
        the main thread), a warning is logged and the runner works without
        signal handling instead of failing to construct.
        """
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, self._on_signal)
            except (NotImplementedError, RuntimeError) as e:
                logger.warning(
                    f"Runner {self} could not install signal handlers ({e!r}); "
                    "signals will not cancel this runner"
                )
                return

    def _on_signal(self):
        """Schedule :meth:`_sig_handler` on the loop, keeping a strong reference."""
        sig_task = asyncio.create_task(self._sig_handler())
        self._sig_tasks.add(sig_task)
        # Release the reference once the handler has finished.
        sig_task.add_done_callback(self._sig_tasks.discard)

    async def _sig_handler(self):
        """Log the interruption and cancel all registered tasks."""
        logger.warning(f"Interruption detected. Canceling runner {self}")
        await self.cancel()

    def __str__(self):
        return self.name