File size: 931 Bytes
b63cd34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
"""

import contextlib
from unittest.mock import patch

from typing import Any


class CapturedCallException(Exception):
    """Exception that carries the arguments of an intercepted call.

    The positional arguments are exposed through the standard ``args``
    attribute of the exception and the keyword arguments through ``kwargs``.
    """

    def __init__(self, *args, **kwargs):
        # Handing the positionals to the base class stores them in
        # ``self.args``; keyword arguments are not accepted by
        # ``Exception.__init__`` and are therefore kept separately.
        super().__init__(*args)
        self.kwargs = kwargs


class CapturedCall:
    """Mutable record of the arguments of a captured call.

    A new instance starts out with no positional and no keyword arguments.
    """

    def __init__(self):
        # Positional arguments of the captured call (empty by default).
        self.args: tuple[Any, ...] = tuple()
        # Keyword arguments of the captured call; a fresh dict per instance.
        self.kwargs: dict[str, Any] = dict()


@contextlib.contextmanager
def capture_component_call(
    pipeline: Any,
    component_name: str,
    component_method='forward',
):
    """Capture the arguments of a call to a method of a pipeline component.

    While the context is active, the attribute ``component_method`` of
    ``getattr(pipeline, component_name)`` is replaced by a stand-in that
    raises :class:`CapturedCallException` carrying the arguments it was
    called with. When that exception propagates out of the ``with`` body it
    is caught here and its arguments are copied onto the yielded
    :class:`CapturedCall`, so the body is cut short at the patched call and
    the results are available once the ``with`` block has exited.

    Args:
        pipeline: Object that holds the component as an attribute.
        component_name: Name of the attribute on ``pipeline`` to look up.
        component_method: Name of the attribute on the component to patch.

    Yields:
        A :class:`CapturedCall` whose ``args`` / ``kwargs`` stay empty unless
        a ``CapturedCallException`` escapes the ``with`` body.
    """
    target = getattr(pipeline, component_name)
    record = CapturedCall()

    def _raise_with_arguments(*args, **kwargs):
        # Abort the caller and transport the call arguments via the exception.
        raise CapturedCallException(*args, **kwargs)

    patcher = patch.object(target, component_method, new=_raise_with_arguments)
    with patcher:
        try:
            yield record
        except CapturedCallException as intercepted:
            # Swallow the exception; only its payload is of interest.
            record.args, record.kwargs = intercepted.args, intercepted.kwargs