Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Commit 
							
							·
						
						a415c67
	
1
								Parent(s):
							
							ab063cf
								
minor integration adjustments
Browse files- middlewares.py +22 -24
- server.py +7 -9
- validators/__init__.py +2 -2
- validators/sn1_validator_wrapper.py +4 -5
    	
        middlewares.py
    CHANGED
    
    | @@ -1,34 +1,32 @@ | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import json
         | 
| 3 | 
             
            import bittensor as bt
         | 
| 4 | 
            -
            from aiohttp.web import Response
         | 
| 5 |  | 
| 6 | 
             
            EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
         | 
| 7 |  | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 |  | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 |  | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
                return middleware_handler
         | 
| 22 |  | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 |  | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
                return middleware_handler
         | 
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import json
         | 
| 3 | 
             
            import bittensor as bt
         | 
| 4 | 
            +
            from aiohttp.web import Request, Response, middleware
         | 
| 5 |  | 
| 6 | 
             
            EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
         | 
| 7 |  | 
| 8 | 
            +
            @middleware
         | 
| 9 | 
            +
            async def api_key_middleware(request: Request, handler):    
         | 
| 10 | 
            +
                # Logging the request
         | 
| 11 | 
            +
                bt.logging.info(f"Handling {request.method} request to {request.path}")
         | 
| 12 |  | 
| 13 | 
            +
                # Check access key
         | 
| 14 | 
            +
                access_key = request.headers.get("api_key")
         | 
| 15 | 
            +
                if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
         | 
| 16 | 
            +
                    bt.logging.error(f'Invalid access key: {access_key}')
         | 
| 17 | 
            +
                    return Response(status=401, reason="Invalid access key")
         | 
| 18 |  | 
| 19 | 
            +
                # Continue to the next handler if the API key is valid
         | 
| 20 | 
            +
                return await handler(request)    
         | 
|  | |
| 21 |  | 
| 22 | 
            +
            @middleware
         | 
| 23 | 
            +
            async def json_parsing_middleware(request: Request, handler):    
         | 
| 24 | 
            +
                try:
         | 
| 25 | 
            +
                    # Parsing JSON data from the request
         | 
| 26 | 
            +
                    request['data'] = await request.json()
         | 
| 27 | 
            +
                except json.JSONDecodeError as e:
         | 
| 28 | 
            +
                    bt.logging.error(f'Invalid JSON data: {str(e)}')
         | 
| 29 | 
            +
                    return Response(status=400, text="Invalid JSON")
         | 
| 30 |  | 
| 31 | 
            +
                # Continue to the next handler if JSON is successfully parsed
         | 
| 32 | 
            +
                return await handler(request)
         | 
|  | 
    	
        server.py
    CHANGED
    
    | @@ -34,8 +34,6 @@ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.n | |
| 34 | 
             
            ```
         | 
| 35 | 
             
            add --mock to test the echo stream
         | 
| 36 | 
             
            """
         | 
| 37 | 
            -
            @api_key_middleware
         | 
| 38 | 
            -
            @json_parsing_middleware
         | 
| 39 | 
             
            async def chat(request: web.Request) -> Response:
         | 
| 40 | 
             
                """
         | 
| 41 | 
             
                Chat endpoint for the validator.
         | 
| @@ -43,7 +41,7 @@ async def chat(request: web.Request) -> Response: | |
| 43 | 
             
                request_data = request['data']
         | 
| 44 | 
             
                params = QueryValidatorParams.from_dict(request_data)    
         | 
| 45 | 
             
                # TODO: SET STREAM AS DEFAULT
         | 
| 46 | 
            -
                stream = request_data.get('stream',  | 
| 47 |  | 
| 48 | 
             
                # Access the validator from the application context
         | 
| 49 | 
             
                validator: ValidatorAPI = request.app['validator']
         | 
| @@ -52,29 +50,29 @@ async def chat(request: web.Request) -> Response: | |
| 52 | 
             
                return response
         | 
| 53 |  | 
| 54 |  | 
| 55 | 
            -
            @api_key_middleware
         | 
| 56 | 
            -
            @json_parsing_middleware
         | 
| 57 | 
             
            async def echo_stream(request, request_data):    
         | 
| 58 | 
             
                request_data = request['data']
         | 
| 59 | 
             
                return await utils.echo_stream(request_data)
         | 
| 60 |  | 
| 61 |  | 
|  | |
| 62 | 
             
            class ValidatorApplication(web.Application):
         | 
| 63 | 
             
                def __init__(self, validator_instance=None, *args, **kwargs):
         | 
| 64 | 
             
                    super().__init__(*args, **kwargs)
         | 
| 65 |  | 
| 66 | 
             
                    self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
         | 
| 67 |  | 
| 68 | 
            -
                    # Add middlewares to application
         | 
| 69 | 
            -
                    self.middlewares.append(api_key_middleware)
         | 
| 70 | 
            -
                    self.middlewares.append(json_parsing_middleware)
         | 
| 71 | 
            -
                    
         | 
| 72 | 
             
                    self.add_routes([
         | 
| 73 | 
             
                        web.post('/chat/', chat),
         | 
| 74 | 
             
                        web.post('/echo/', echo_stream)
         | 
| 75 | 
             
                    ])
         | 
|  | |
| 76 | 
             
                    # TODO: Enable rewarding and other features
         | 
| 77 |  | 
|  | |
|  | |
|  | |
| 78 |  | 
| 79 | 
             
            def main(run_aio_app=True, test=False) -> None:
         | 
| 80 | 
             
                loop = asyncio.get_event_loop()
         | 
|  | |
| 34 | 
             
            ```
         | 
| 35 | 
             
            add --mock to test the echo stream
         | 
| 36 | 
             
            """
         | 
|  | |
|  | |
| 37 | 
             
            async def chat(request: web.Request) -> Response:
         | 
| 38 | 
             
                """
         | 
| 39 | 
             
                Chat endpoint for the validator.
         | 
|  | |
| 41 | 
             
                request_data = request['data']
         | 
| 42 | 
             
                params = QueryValidatorParams.from_dict(request_data)    
         | 
| 43 | 
             
                # TODO: SET STREAM AS DEFAULT
         | 
| 44 | 
            +
                stream = request_data.get('stream', True)        
         | 
| 45 |  | 
| 46 | 
             
                # Access the validator from the application context
         | 
| 47 | 
             
                validator: ValidatorAPI = request.app['validator']
         | 
|  | |
| 50 | 
             
                return response
         | 
| 51 |  | 
| 52 |  | 
|  | |
|  | |
| 53 | 
             
            async def echo_stream(request, request_data):    
         | 
| 54 | 
             
                request_data = request['data']
         | 
| 55 | 
             
                return await utils.echo_stream(request_data)
         | 
| 56 |  | 
| 57 |  | 
| 58 | 
            +
             | 
| 59 | 
             
            class ValidatorApplication(web.Application):
         | 
| 60 | 
             
                def __init__(self, validator_instance=None, *args, **kwargs):
         | 
| 61 | 
             
                    super().__init__(*args, **kwargs)
         | 
| 62 |  | 
| 63 | 
             
                    self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
         | 
| 64 |  | 
| 65 | 
            +
                    # Add middlewares to application                
         | 
|  | |
|  | |
|  | |
| 66 | 
             
                    self.add_routes([
         | 
| 67 | 
             
                        web.post('/chat/', chat),
         | 
| 68 | 
             
                        web.post('/echo/', echo_stream)
         | 
| 69 | 
             
                    ])
         | 
| 70 | 
            +
                    self.setup_middlewares()
         | 
| 71 | 
             
                    # TODO: Enable rewarding and other features
         | 
| 72 |  | 
| 73 | 
            +
                def setup_middlewares(self):
         | 
| 74 | 
            +
                    self.middlewares.append(json_parsing_middleware)
         | 
| 75 | 
            +
                    self.middlewares.append(api_key_middleware)
         | 
| 76 |  | 
| 77 | 
             
            def main(run_aio_app=True, test=False) -> None:
         | 
| 78 | 
             
                loop = asyncio.get_event_loop()
         | 
    	
        validators/__init__.py
    CHANGED
    
    | @@ -1,2 +1,2 @@ | |
| 1 | 
            -
            from base import QueryValidatorParams, ValidatorAPI, MockValidator
         | 
| 2 | 
            -
            from sn1_validator_wrapper import S1ValidatorAPI
         | 
|  | |
| 1 | 
            +
            from .base import QueryValidatorParams, ValidatorAPI, MockValidator
         | 
| 2 | 
            +
            from .sn1_validator_wrapper import S1ValidatorAPI
         | 
    	
        validators/sn1_validator_wrapper.py
    CHANGED
    
    | @@ -1,14 +1,13 @@ | |
| 1 | 
             
            import json
         | 
| 2 | 
             
            import utils
         | 
|  | |
| 3 | 
             
            import traceback
         | 
| 4 | 
             
            import bittensor as bt
         | 
| 5 | 
            -
            import asyncio
         | 
| 6 | 
            -
            from prompting.forward import handle_response
         | 
| 7 | 
             
            from prompting.validator import Validator
         | 
| 8 | 
             
            from prompting.utils.uids import get_random_uids
         | 
| 9 | 
             
            from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
         | 
| 10 | 
             
            from prompting.dendrite import DendriteResponseEvent
         | 
| 11 | 
            -
            from base import QueryValidatorParams, ValidatorAPI
         | 
| 12 | 
             
            from aiohttp.web_response import Response, StreamResponse
         | 
| 13 | 
             
            from deprecated import deprecated
         | 
| 14 |  | 
| @@ -16,7 +15,7 @@ class S1ValidatorAPI(ValidatorAPI): | |
| 16 | 
             
                def __init__(self):
         | 
| 17 | 
             
                    self.validator = Validator()    
         | 
| 18 |  | 
| 19 | 
            -
             | 
| 20 | 
             
                @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
         | 
| 21 | 
             
                async def get_response(self, params:QueryValidatorParams) -> Response:
         | 
| 22 | 
             
                    try:
         | 
| @@ -37,7 +36,7 @@ class S1ValidatorAPI(ValidatorAPI): | |
| 37 |  | 
| 38 | 
             
                        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
         | 
| 39 | 
             
                        # Encapsulate the responses in a response event (dataclass)
         | 
| 40 | 
            -
                        response_event = DendriteResponseEvent(responses, uids)
         | 
| 41 |  | 
| 42 | 
             
                        # convert dict to json
         | 
| 43 | 
             
                        response = response_event.__state_dict__()
         | 
|  | |
| 1 | 
             
            import json
         | 
| 2 | 
             
            import utils
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
             
            import traceback
         | 
| 5 | 
             
            import bittensor as bt
         | 
|  | |
|  | |
| 6 | 
             
            from prompting.validator import Validator
         | 
| 7 | 
             
            from prompting.utils.uids import get_random_uids
         | 
| 8 | 
             
            from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
         | 
| 9 | 
             
            from prompting.dendrite import DendriteResponseEvent
         | 
| 10 | 
            +
            from .base import QueryValidatorParams, ValidatorAPI
         | 
| 11 | 
             
            from aiohttp.web_response import Response, StreamResponse
         | 
| 12 | 
             
            from deprecated import deprecated
         | 
| 13 |  | 
|  | |
| 15 | 
             
                def __init__(self):
         | 
| 16 | 
             
                    self.validator = Validator()    
         | 
| 17 |  | 
| 18 | 
            +
             | 
| 19 | 
             
                @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
         | 
| 20 | 
             
                async def get_response(self, params:QueryValidatorParams) -> Response:
         | 
| 21 | 
             
                    try:
         | 
|  | |
| 36 |  | 
| 37 | 
             
                        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
         | 
| 38 | 
             
                        # Encapsulate the responses in a response event (dataclass)
         | 
| 39 | 
            +
                        response_event = DendriteResponseEvent(responses, torch.LongTensor(uids), params.timeout)
         | 
| 40 |  | 
| 41 | 
             
                        # convert dict to json
         | 
| 42 | 
             
                        response = response_event.__state_dict__()
         | 
