Spaces:
Paused
Paused
| # Copyright 2019 gRPC authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Invocation-side implementation of gRPC Asyncio Python.""" | |
| import asyncio | |
| import sys | |
| from typing import Any, Iterable, List, Optional, Sequence | |
| import grpc | |
| from grpc import _common | |
| from grpc import _compression | |
| from grpc import _grpcio_metadata | |
| from grpc._cython import cygrpc | |
| from . import _base_call | |
| from . import _base_channel | |
| from ._call import StreamStreamCall | |
| from ._call import StreamUnaryCall | |
| from ._call import UnaryStreamCall | |
| from ._call import UnaryUnaryCall | |
| from ._interceptor import ClientInterceptor | |
| from ._interceptor import InterceptedStreamStreamCall | |
| from ._interceptor import InterceptedStreamUnaryCall | |
| from ._interceptor import InterceptedUnaryStreamCall | |
| from ._interceptor import InterceptedUnaryUnaryCall | |
| from ._interceptor import StreamStreamClientInterceptor | |
| from ._interceptor import StreamUnaryClientInterceptor | |
| from ._interceptor import UnaryStreamClientInterceptor | |
| from ._interceptor import UnaryUnaryClientInterceptor | |
| from ._metadata import Metadata | |
| from ._typing import ChannelArgumentType | |
| from ._typing import DeserializingFunction | |
| from ._typing import MetadataType | |
| from ._typing import RequestIterableType | |
| from ._typing import RequestType | |
| from ._typing import ResponseType | |
| from ._typing import SerializingFunction | |
| from ._utils import _timeout_to_deadline | |
# Value used for the primary user-agent channel argument; it embeds
# `_grpcio_metadata.__version__`.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"
# `asyncio.all_tasks` only exists from Python 3.7 on; older interpreters expose
# the same functionality as `asyncio.Task.all_tasks`. Compare the full version
# tuple rather than only the minor component so the check stays correct for
# any major version.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return all tasks using the pre-3.7 ``asyncio.Task.all_tasks`` API."""
        return asyncio.Task.all_tasks()  # pylint: disable=no-member

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return all tasks as reported by ``asyncio.all_tasks()``."""
        return asyncio.all_tasks()
def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extend the caller's channel options with compression and user agent.

    Args:
      base_options: Channel options supplied by the caller.
      compression: Optional compression setting, translated into channel
        options by `_compression.create_channel_option`.

    Returns:
      A tuple made of `base_options`, followed by the compression options,
      followed by the primary user-agent option.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return tuple(base_options) + compression_options + (user_agent_option,)
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the attributes shared by every multi callable.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The encoded method name.
          request_serializer: Serializer handed to the created calls.
          response_deserializer: Deserializer handed to the created calls.
          interceptors: Optional interceptors; when empty or None the
            subclasses create plain (non-intercepted) calls.
          references: Objects stored on the instance as `_references`.
          loop: The event loop handed to the created calls.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    # Must be a static method: it takes no `self`, yet subclasses invoke it as
    # `self._init_metadata(metadata, compression)`.
    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Build the final metadata to be used for the current call.

        Args:
          metadata: Optional caller-supplied metadata. A falsy value is
            replaced by an empty `Metadata`; a plain tuple is converted with
            `Metadata.from_tuple`.
          compression: Optional compression setting; when truthy the metadata
            is rebuilt from `_compression.augment_metadata`.

        Returns:
          The metadata object to pass to the call.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata
class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Multi callable that builds unary-unary call objects."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Build and return the call object for one invocation.

        Args:
          request: The request value handed to the call.
          timeout: Optional timeout; converted with `_timeout_to_deadline`
            for plain calls and passed through unchanged to intercepted calls.
          metadata: Optional metadata, normalised by `_init_metadata`.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional compression setting, folded into the metadata.

        Returns:
          An `InterceptedUnaryUnaryCall` when interceptors are configured,
          otherwise a `UnaryUnaryCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing positional arguments are the same for both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, *shared_args
            )
        return UnaryUnaryCall(
            request, _timeout_to_deadline(timeout), *shared_args
        )
class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Multi callable that builds unary-stream call objects."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Build and return the call object for one invocation.

        Args:
          request: The request value handed to the call.
          timeout: Optional timeout; converted with `_timeout_to_deadline`
            for plain calls and passed through unchanged to intercepted calls.
          metadata: Optional metadata, normalised by `_init_metadata`.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional compression setting, folded into the metadata.

        Returns:
          An `InterceptedUnaryStreamCall` when interceptors are configured,
          otherwise a `UnaryStreamCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing positional arguments are the same for both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, *shared_args
            )
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), *shared_args
        )
class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Multi callable that builds stream-unary call objects."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Build and return the call object for one invocation.

        Args:
          request_iterator: Optional iterable of requests handed to the call.
          timeout: Optional timeout; converted with `_timeout_to_deadline`
            for plain calls and passed through unchanged to intercepted calls.
          metadata: Optional metadata, normalised by `_init_metadata`.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional compression setting, folded into the metadata.

        Returns:
          An `InterceptedStreamUnaryCall` when interceptors are configured,
          otherwise a `StreamUnaryCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing positional arguments are the same for both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, *shared_args
            )
        return StreamUnaryCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_args
        )
class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Multi callable that builds stream-stream call objects."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Build and return the call object for one invocation.

        Args:
          request_iterator: Optional iterable of requests handed to the call.
          timeout: Optional timeout; converted with `_timeout_to_deadline`
            for plain calls and passed through unchanged to intercepted calls.
          metadata: Optional metadata, normalised by `_init_metadata`.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional compression setting, folded into the metadata.

        Returns:
          An `InterceptedStreamStreamCall` when interceptors are configured,
          otherwise a `StreamStreamCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing positional arguments are the same for both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, *shared_args
            )
        return StreamStreamCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_args
        )
class Channel(_base_channel.Channel):
    """Asyncio channel wrapping a `cygrpc.AioChannel`.

    Interceptors given at construction are sorted into one list per RPC arity;
    each multi-callable factory method hands the matching list to the multi
    callable it creates.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # An interceptor is registered under the first matching type
                # only.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        """Enter the async context manager and return this channel."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the channel without a grace period on context exit."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Cancel the calls bound to this channel and close it.

        Args:
          grace: Optional timeout handed to `asyncio.wait` for the tasks of
            the located calls before those calls are cancelled. A falsy value
            skips the wait.

        Raises:
          cygrpc.InternalError: If a located `_base_call.Call` exposes neither
            `_channel` nor `_cython_call`.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel, optionally waiting up to `grace` for its calls."""
        await self._close(grace)

    def __del__(self):
        """Close the underlying channel if it is still open."""
        # `_channel` is missing when `__init__` raised before assigning it.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the current connectivity state of the channel.

        Args:
          try_to_connect: Forwarded to the Cython channel's
            `check_connectivity_state`.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the channel state differs from `last_observed_state`.

        Args:
          last_observed_state: The state the caller last saw.
        """
        # The await must not live inside the `assert` expression: `python -O`
        # strips assert statements entirely, which would skip the wait and
        # turn `channel_ready` into a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        # NOTE(review): the second argument (None) is presumably "no
        # deadline", in which case the watch is expected to report a change.
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until `get_state` reports `ChannelConnectivity.READY`."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Create a unary-unary multi callable for `method`.

        `_registered_method` is accepted but currently unused.
        """
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Create a unary-stream multi callable for `method`.

        `_registered_method` is accepted but currently unused.
        """
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Create a stream-unary multi callable for `method`.

        `_registered_method` is accepted but currently unused.
        """
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Create a stream-stream multi callable for `method`.

        `_registered_method` is accepted but currently unused.
        """
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )
def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    # A missing options argument becomes an empty tuple.
    channel_options = options if options is not None else ()
    # The credentials slot is None for an insecure channel.
    return Channel(target, channel_options, None, compression, interceptors)
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    # A missing options argument becomes an empty tuple.
    channel_options = options if options is not None else ()
    # The channel receives the object stored on the wrapper's `_credentials`.
    return Channel(
        target,
        channel_options,
        credentials._credentials,
        compression,
        interceptors,
    )