File size: 3,326 Bytes
8d7f55c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import dataclasses

import pipecat.frames.protobufs.frames_pb2 as frame_protos

from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame, TranscriptionFrame
from pipecat.serializers.base_serializer import FrameSerializer

from loguru import logger


class ProtobufFrameSerializer(FrameSerializer):
    SERIALIZABLE_TYPES = {
        TextFrame: "text",
        AudioRawFrame: "audio",
        TranscriptionFrame: "transcription"
    }

    SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}

    def __init__(self):
        pass

    def serialize(self, frame: Frame) -> str | bytes | None:
        proto_frame = frame_protos.Frame()
        if type(frame) not in self.SERIALIZABLE_TYPES:
            raise ValueError(
                f"Frame type {type(frame)} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_FIELDS.")

        # ignoring linter errors; we check that type(frame) is in this dict above
        proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)]  # type: ignore
        for field in dataclasses.fields(frame):  # type: ignore
            setattr(getattr(proto_frame, proto_optional_name), field.name,
                    getattr(frame, field.name))

        result = proto_frame.SerializeToString()
        return result

    def deserialize(self, data: str | bytes) -> Frame | None:
        """Returns a Frame object from a Frame protobuf. Used to convert frames

        passed over the wire as protobufs to Frame objects used in pipelines

        and frame processors.



        >>> serializer = ProtobufFrameSerializer()

        >>> serializer.deserialize(

        ...     serializer.serialize(AudioFrame(data=b'1234567890')))

        AudioFrame(data=b'1234567890')



        >>> serializer.deserialize(

        ...     serializer.serialize(TextFrame(text='hello world')))

        TextFrame(text='hello world')



        >>> serializer.deserialize(serializer.serialize(TranscriptionFrame(

        ...     text="Hello there!", participantId="123", timestamp="2021-01-01")))

        TranscriptionFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')

        """

        proto = frame_protos.Frame.FromString(data)
        which = proto.WhichOneof("frame")
        if which not in self.SERIALIZABLE_FIELDS:
            logger.error("Unable to deserialize a valid frame")
            return None

        class_name = self.SERIALIZABLE_FIELDS[which]
        args = getattr(proto, which)
        args_dict = {}
        for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields:
            args_dict[field.name] = getattr(args, field.name)

        # Remove special fields if needed
        id = getattr(args, "id")
        name = getattr(args, "name")
        if not id:
            del args_dict["id"]
        if not name:
            del args_dict["name"]

        # Create the instance
        instance = class_name(**args_dict)

        # Set special fields
        if id:
            setattr(instance, "id", getattr(args, "id"))
        if name:
            setattr(instance, "name", getattr(args, "name"))

        return instance