Spaces:
Sleeping
Sleeping
| # Copyright 2015 gRPC authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import collections | |
| import logging | |
| import threading | |
| from typing import Callable, Optional, Type | |
| import grpc | |
| from grpc import _common | |
| from grpc._cython import cygrpc | |
| from grpc._typing import MetadataType | |
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(name=__name__)
class _AuthMetadataContext(
    collections.namedtuple("AuthMetadataContext", "service_url method_name"),
    grpc.AuthMetadataContext,
):
    """Immutable grpc.AuthMetadataContext implementation.

    A namedtuple with two positional fields, ``service_url`` and
    ``method_name`` (in that order), that also satisfies the
    ``grpc.AuthMetadataContext`` interface.
    """
class _CallbackState:
    """Bookkeeping shared between one plugin invocation and its callback.

    Attributes:
      called: False until the callback has been invoked.
      exception: None, or the exception the plugin raised.
      lock: threading.Lock that callers hold while reading or updating
        ``called`` and ``exception``.
    """

    def __init__(self):
        self.called = False
        self.exception = None
        self.lock = threading.Lock()
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Single-use callback handed to an application AuthMetadataPlugin.

    Relays the plugin's outcome to the lower-level ``callback`` as a
    ``(metadata, status_code, details)`` triple, and guards against the
    callback being used more than once or after the plugin has raised.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        """Bind the shared per-invocation state and the downstream callback.

        Args:
          state: Bookkeeping shared with the code that invoked the plugin.
          callback: Callable taking ``(metadata, status_code, details)``.
        """
        self._state = state
        self._callback = callback

    def __call__(
        self, metadata: MetadataType, error: Optional[BaseException]
    ):
        """Report the plugin's result.

        Args:
          metadata: Metadata to attach to the call. It is only forwarded
            when ``error`` is None.
          error: None on success, otherwise the exception *instance*
            describing the failure. The annotation previously claimed an
            exception class, but the value is rendered with ``str(error)``.

        Raises:
          RuntimeError: If the plugin already raised an exception (recorded
            in the shared state), or if this callback was already invoked.
        """
        # Check and mark the state atomically, so that two racing invocations
        # cannot both get past the "called" check.
        with self._state.lock:
            if self._state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception
                    )
                )
            if self._state.called:
                raise RuntimeError(
                    "AuthMetadataPluginCallback invoked more than once!"
                )
            self._state.called = True
        # The downstream callback is invoked outside the lock.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )
class _Plugin:
    """Adapter that lets cygrpc invoke an application AuthMetadataPlugin."""

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        """Wrap ``metadata_plugin`` and snapshot the current contextvars context.

        ``_stored_ctx`` is None when the ``contextvars`` module is unavailable.
        """
        self._metadata_plugin = metadata_plugin
        try:
            import contextvars  # pylint: disable=wrong-import-position
        except ImportError:
            # Interpreter predates contextvars: there is no context to capture.
            self._stored_ctx = None
        else:
            # Core may run the plugin on a thread it created itself, and that
            # thread does not inherit this context. The snapshot is kept so it
            # can be installed on the thread that invokes the plugin.
            self._stored_ctx = contextvars.copy_context()

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Run the wrapped plugin for one RPC.

        Args:
          service_url: Encoded service URL; decoded before use.
          method_name: Encoded method name; decoded before use.
          callback: Callable taking ``(metadata, status_code, details)``.

        If the plugin raises before it has used its callback, ``callback`` is
        invoked with ``StatusCode.internal`` and the encoded exception text.
        """
        auth_context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        state = _CallbackState()
        try:
            self._metadata_plugin(
                auth_context, _AuthMetadataPluginCallback(state, callback)
            )
        except Exception as plugin_error:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            # Record the failure under the lock. Any later use of the plugin's
            # callback will then be rejected, and we learn whether it already
            # answered.
            with state.lock:
                state.exception = plugin_error
                already_answered = state.called
            if not already_answered:
                callback(
                    None,
                    cygrpc.StatusCode.internal,
                    _common.encode(str(plugin_error)),
                )
def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
) -> grpc.CallCredentials:
    """Build grpc.CallCredentials backed by an AuthMetadataPlugin.

    Args:
      metadata_plugin: The application plugin that supplies call metadata.
      name: Label for the credentials. When None, the plugin's ``__name__``
        is used, falling back to the name of the plugin's class.

    Returns:
      A grpc.CallCredentials wrapping a cygrpc.MetadataPluginCallCredentials.
    """
    if name is not None:
        effective_name = name
    else:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            # Plugin instances (as opposed to plain functions) usually lack
            # __name__, so fall back to their class name.
            effective_name = metadata_plugin.__class__.__name__
    wrapped_plugin = _Plugin(metadata_plugin)
    encoded_name = _common.encode(effective_name)
    native_credentials = cygrpc.MetadataPluginCallCredentials(
        wrapped_plugin, encoded_name
    )
    return grpc.CallCredentials(native_credentials)