#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import dataclasses

import pipecat.frames.protobufs.frames_pb2 as frame_protos

from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame, TranscriptionFrame
from pipecat.serializers.base_serializer import FrameSerializer

from loguru import logger


class ProtobufFrameSerializer(FrameSerializer):
    """Convert a fixed set of frame types to and from ``frames_pb2.Frame`` protobufs.

    Each supported frame class maps to the name of one field of the ``frame``
    oneof in the protobuf ``Frame`` message. Frame dataclass fields are copied
    by name into (and back out of) that sub-message, so the protobuf field
    names are expected to match the dataclass field names.
    """

    # Frame class -> name of the matching field in the protobuf "frame" oneof.
    SERIALIZABLE_TYPES = {
        TextFrame: "text",
        AudioRawFrame: "audio",
        TranscriptionFrame: "transcription",
    }

    # Reverse lookup (oneof field name -> frame class), used by deserialize().
    SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}

    # Attributes that are never passed to the frame constructor; when present
    # and non-empty in the protobuf they are assigned after construction.
    _SPECIAL_FIELDS = ("id", "name")

    def __init__(self):
        pass

    def serialize(self, frame: Frame) -> str | bytes | None:
        """Encode ``frame`` as a serialized ``frames_pb2.Frame`` message.

        Args:
            frame: The frame to encode. Its exact type must be a key of
                ``SERIALIZABLE_TYPES``; subclasses are not matched.

        Returns:
            The protobuf wire-format bytes.

        Raises:
            ValueError: If ``type(frame)`` is not a key of ``SERIALIZABLE_TYPES``.
        """
        frame_type = type(frame)
        # Exact-type lookup (not isinstance): a subclass of a supported frame is
        # rejected instead of being silently encoded as its parent's field.
        if frame_type not in self.SERIALIZABLE_TYPES:
            raise ValueError(
                f"Frame type {frame_type} is not serializable. You may need to add it to "
                "ProtobufFrameSerializer.SERIALIZABLE_TYPES."
            )

        proto_frame = frame_protos.Frame()
        # Copy every dataclass field into the oneof sub-message of the same name.
        sub_message = getattr(proto_frame, self.SERIALIZABLE_TYPES[frame_type])  # type: ignore
        for field in dataclasses.fields(frame):  # type: ignore
            setattr(sub_message, field.name, getattr(frame, field.name))

        return proto_frame.SerializeToString()

    def deserialize(self, data: str | bytes) -> Frame | None:
        """Decode a serialized ``frames_pb2.Frame`` into a frame object.

        Used to turn frames received over the wire as protobufs back into the
        ``Frame`` objects that pipelines and frame processors work with. For
        example, ``deserialize(serialize(TextFrame(text="hello world")))`` is
        expected to produce a ``TextFrame`` whose ``text`` is ``"hello world"``.

        Args:
            data: Protobuf wire-format data passed to ``Frame.FromString``.
                NOTE(review): protobuf parsing normally expects ``bytes``;
                confirm whether callers ever pass ``str``.

        Returns:
            A new instance of the frame class mapped to whichever oneof field
            is set. Returns ``None`` (after logging an error) when no supported
            field is set.
        """
        proto = frame_protos.Frame.FromString(data)
        which = proto.WhichOneof("frame")
        if which not in self.SERIALIZABLE_FIELDS:
            logger.error("Unable to deserialize a valid frame")
            return None

        frame_class = self.SERIALIZABLE_FIELDS[which]
        sub_message = getattr(proto, which)

        # Read every field declared on the sub-message type.
        args_dict = {
            field.name: getattr(sub_message, field.name)
            for field in sub_message.DESCRIPTOR.fields
        }

        # Special fields are always removed before construction, whatever their
        # value. They are restored with setattr afterwards instead of being
        # passed as kwargs (assumes Frame declares them as non-init dataclass
        # fields -- confirm in pipecat.frames.frames).
        special = {}
        for attr in self._SPECIAL_FIELDS:
            if attr in args_dict:
                special[attr] = args_dict.pop(attr)

        instance = frame_class(**args_dict)

        for attr, value in special.items():
            # Falsy values (0 / "") are treated as "not provided"; the frame
            # keeps whatever its own construction assigned.
            if value:
                setattr(instance, attr, value)

        return instance