Commit: 2a5b08d
Parent(s): 3b4a72e
initial streamer implementation
Files changed:
- utils.py +20 -41
- validators/__init__.py +1 -0
- validators/sn1_validator_wrapper.py +29 -156
- validators/streamer.py +121 -0
    	
utils.py (CHANGED)

@@ -1,12 +1,10 @@
 import re
-import time
-import json
 import asyncio
 import bittensor as bt
 from aiohttp import web
-from responses import TextStreamResponse
 from collections import Counter
 from prompting.rewards import DateRewardModel, FloatDiffModel
+from validators.streamer import AsyncResponseDataStreamer

 UNSUCCESSFUL_RESPONSE_PATTERNS = [
     "I'm sorry",

@@ -136,46 +134,27 @@ def guess_task_name(challenge: str):
     return "qa"


+# Simulate the stream synapse for the echo endpoint
+class EchoAsyncIterator:
+    def __init__(self, message: str, k: int, delay: float):
+        self.message = message
+        self.k = k
+        self.delay = delay
+
+    async def __aiter__(self):
+        for _ in range(self.k):
+            for word in self.message.split():
+                yield [word]
+                await asyncio.sleep(self.delay)
+
+
 async def echo_stream(request: web.Request) -> web.StreamResponse:
     request_data = request["data"]
     k = request_data.get("k", 1)
     message = "\n\n".join(request_data["messages"])

-    ...
-    completion = ""
-    chunks = []
-    chunks_timings = []
-    start_time = time.time()
-    # Echo the message k times with a timeout between each chunk
-    for _ in range(k):
-        for word in message.split():
-            chunk = f"{word} "
-            await response.write(chunk.encode("utf-8"))
-            completion += chunk
-            await asyncio.sleep(0.3)
-            bt.logging.info(f"Echoed: {chunk}")
-
-            chunks.append(chunk)
-            chunks_timings.append(time.time() - start_time)
-
-    completion = completion.strip()
-
-    # Prepare final JSON chunk
-    response_data = TextStreamResponse(
-        streamed_chunks=chunks,
-        streamed_chunks_timings=chunks_timings,
-        completion=completion,
-        timing=time.time() - start_time,
-    ).to_dict()
-
-    # Send the final JSON as part of the stream
-    await response.write(json.dumps(response_data).encode("utf-8"))
-
-    # Finalize the response
-    await response.write_eof()
-    return response
+
+    echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
+    streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)
+
+    return await streamer.stream(request)
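A minimal client sketch (not part of this commit) for exercising the echo endpoint above. The URL, port, and request payload are assumptions inferred from request_data.get("k", 1) and request_data["messages"]; route registration happens elsewhere in the repo.

import asyncio

import aiohttp


async def consume_echo_stream(url: str = "http://localhost:8080/echo/") -> None:
    payload = {"k": 2, "messages": ["hello world"]}
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as resp:
            # Each server-side write is one pretty-printed StreamChunk (or
            # StreamError) JSON document, so a robust client would split the
            # byte stream into individual JSON objects before parsing.
            async for raw in resp.content.iter_any():
                print(raw.decode("utf-8"), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(consume_echo_stream())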
    	
validators/__init__.py (CHANGED)

@@ -1,2 +1,3 @@
 from .base import QueryValidatorParams, ValidatorAPI, MockValidator
 from .sn1_validator_wrapper import S1ValidatorAPI
+from .streamer import AsyncResponseDataStreamer
    	
validators/sn1_validator_wrapper.py (CHANGED)

@@ -1,21 +1,13 @@
-import json
-import utils
-import torch
-import traceback
-import time
 import random
 import bittensor as bt
-from typing import Awaitable
 from prompting.validator import Validator
 from prompting.utils.uids import get_random_uids
-from prompting.protocol import ...
-from prompting.dendrite import DendriteResponseEvent
+from prompting.protocol import StreamPromptingSynapse
 from .base import QueryValidatorParams, ValidatorAPI
 from aiohttp.web_response import Response, StreamResponse
-from deprecated import deprecated
 from dataclasses import dataclass
 from typing import List
-from ...
+from .streamer import AsyncResponseDataStreamer


 @dataclass

@@ -29,154 +21,35 @@ class S1ValidatorAPI(ValidatorAPI):
     def __init__(self):
         self.validator = Validator()

-    @deprecated(
-        reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead."
-    )
-    async def get_response(self, params: QueryValidatorParams) -> Response:
-        try:
-            # Guess the task name of current request
-            task_name = utils.guess_task_name(params.messages[-1])
-
-            # Get the list of uids to query for this step.
-            uids = get_random_uids(
-                self.validator, k=params.k_miners, exclude=params.exclude or []
-            ).tolist()
-            axons = [self.validator.metagraph.axons[uid] for uid in uids]
-
-            # Make calls to the network with the prompt.
-            bt.logging.info(f"Calling dendrite")
-            responses = await self.validator.dendrite(
-                axons=axons,
-                synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
-                timeout=params.timeout,
-            )
-
-            bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
-            # Encapsulate the responses in a response event (dataclass)
-            response_event = DendriteResponseEvent(
-                responses, torch.LongTensor(uids), params.timeout
-            )
-
-            # convert dict to json
-            response = response_event.__state_dict__()
-
-            response["completion_is_valid"] = valid = list(
-                map(utils.completion_is_valid, response["completions"])
-            )
-            valid_completions = [
-                response["completions"][i] for i, v in enumerate(valid) if v
-            ]
-
-            response["task_name"] = task_name
-            response["ensemble_result"] = utils.ensemble_result(
-                valid_completions, task_name=task_name, prefer=params.prefer
-            )
-
-            bt.logging.info(f"Response:\n {response}")
-            return Response(
-                status=200,
-                reason="I can't believe it's not butter!",
-                text=json.dumps(response),
-            )
-
-        except Exception:
-            bt.logging.error(
-                f"Encountered in {self.__class__.__name__}:get_response:\n{traceback.format_exc()}"
-            )
-            return Response(status=500, reason="Internal error")
-
-    async def process_response(
-        self, response: StreamResponse, async_generator: Awaitable
-    ) -> ProcessedStreamResponse:
-        """Process a single response asynchronously."""
-        # Initialize chunk with a default value
-        chunk = None
-        # Initialize chunk array to accumulate streamed chunks
-        chunks = []
-        chunks_timings = []
-
-        start_time = time.time()
-        last_sent_index = 0
-        async for chunk in async_generator:
-            if isinstance(chunk, list):
-                # Chunks are currently returned in string arrays, so we need to concatenate them
-                concatenated_chunks = "".join(chunk)
-                new_data = concatenated_chunks[last_sent_index:]
-
-                if new_data:
-                    await response.write(new_data.encode("utf-8"))
-                    bt.logging.info(f"Received new chunk from miner: {chunk}")
-                    last_sent_index += len(new_data)
-                    chunks.extend(chunk)
-                    chunks_timings.append(time.time() - start_time)
-
-        if chunk is not None and isinstance(chunk, StreamPromptingSynapse):
-            # Assuming the last chunk holds the last value yielded which should be a synapse with the completion filled
-            return ProcessedStreamResponse(
-                synapse=chunk,
-                streamed_chunks=chunks,
-                streamed_chunks_timings=chunks_timings,
-            )
-        else:
-            raise ValueError("The last chunks not a StreamPrompting synapse")

     async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
-        ...
-
-            uid_stream_dict = dict(zip(uids, streams_responses))
-
-            random_uid, random_stream = random.choice(list(uid_stream_dict.items()))
-            processed_response = await self.process_response(response, random_stream)
-
-            # Prepare final JSON chunk
-            response_data = json.dumps(
-                TextStreamResponse(
-                    streamed_chunks=processed_response.streamed_chunks,
-                    streamed_chunks_timings=processed_response.streamed_chunks_timings,
-                    uid=random_uid,
-                    completion=processed_response.synapse.completion,
-                    timing=time.time() - start_time,
-                ).to_dict()
-            )
-
-            # Send the final JSON as part of the stream
-            await response.write(json.dumps(response_data).encode("utf-8"))
-        except Exception as e:
-            bt.logging.error(
-                f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
-            )
-            response.set_status(500, reason="Internal error")
-            await response.write(json.dumps({"error": str(e)}).encode("utf-8"))
-        finally:
-            await response.write_eof()  # Ensure to close the response properly
-
+        # Guess the task name of current request
+        # task_name = utils.guess_task_name(params.messages[-1])
+
+        # Get the list of uids to query for this step.
+        uids = get_random_uids(
+            self.validator, k=params.k_miners, exclude=params.exclude or []
+        ).tolist()
+        axons = [self.validator.metagraph.axons[uid] for uid in uids]
+
+        # Make calls to the network with the prompt.
+        bt.logging.info(f"Calling dendrite")
+
+        streams_responses = await self.validator.dendrite(
+            axons=axons,
+            synapse=StreamPromptingSynapse(
+                roles=params.roles, messages=params.messages
+            ),
+            timeout=params.timeout,
+            deserialize=False,
+            streaming=True,
+        )
+        uid_stream_dict = dict(zip(uids, streams_responses))
+        random_uid, random_stream = random.choice(list(uid_stream_dict.items()))
+
+        # Creates a streamer from the selected stream
+        streamer = AsyncResponseDataStreamer(async_iterator=random_stream, selected_uid=random_uid)
+        response = await streamer.stream(params.request)
         return response

     async def query_validator(self, params: QueryValidatorParams) -> Response:
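Design note: the rewritten get_stream_response relays a single randomly chosen miner stream instead of aggregating all k_miners responses, and the closing StreamPromptingSynapse supplies the final completion. Because each chunk written by the streamer is a pretty-printed JSON document (see StreamChunk in validators/streamer.py below) rather than a line-delimited record, a consumer has to find object boundaries itself. A possible client-side splitter, not part of this commit, might track brace depth:

import json
from typing import Iterator


def split_json_documents(buffer: str) -> Iterator[dict]:
    """Yield every complete top-level JSON object found in `buffer`, in order."""
    depth = 0
    start = None
    in_string = False
    escaped = False
    for i, ch in enumerate(buffer):
        if in_string:
            # Skip over string contents so braces inside values are ignored.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            if depth == 0:
                start = i
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0 and start is not None:
                yield json.loads(buffer[start:i + 1])
                start = None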
    	
validators/streamer.py (ADDED)

@@ -0,0 +1,121 @@
+import json
+import time
+import traceback
+import bittensor as bt
+from pydantic import BaseModel
+from datetime import datetime
+from typing import AsyncIterator, Optional, List
+from aiohttp import web, web_response
+from prompting.protocol import StreamPromptingSynapse
+
+
+class StreamChunk(BaseModel):
+    delta: str
+    finish_reason: Optional[str]
+    accumulated_chunks: List[str]
+    accumulated_chunks_timings: List[float]
+    timestamp: str
+    sequence_number: int
+    selected_uid: int
+
+    def encode(self, encoding: str) -> bytes:
+        data = json.dumps(self.dict(), indent=4)
+        return data.encode(encoding)
+
+
+class StreamError(BaseModel):
+    error: str
+    timestamp: str
+    sequence_number: int
+    finish_reason: str = 'error'
+
+    def encode(self, encoding: str) -> bytes:
+        data = json.dumps(self.dict(), indent=4)
+        return data.encode(encoding)
+
+
+class AsyncResponseDataStreamer:
+    def __init__(self, async_iterator: AsyncIterator, selected_uid: int, delay: float = 0.1):
+        self.async_iterator = async_iterator
+        self.delay = delay
+        self.selected_uid = selected_uid
+        self.accumulated_chunks: List[str] = []
+        self.accumulated_chunks_timings: List[float] = []
+        self.finish_reason: str = None
+        self.sequence_number: int = 0
+
+    async def stream(self, request: web.Request) -> web_response.StreamResponse:
+        response = web_response.StreamResponse(status=200, reason="OK")
+        response.headers["Content-Type"] = "application/json"
+        await response.prepare(request)  # Prepare and send the headers
+
+        try:
+            start_time = time.time()
+            async for chunk in self.async_iterator:
+                if isinstance(chunk, list):
+                    # Chunks are currently returned in string arrays, so we need to concatenate them
+                    concatenated_chunks = "".join(chunk)
+                    self.accumulated_chunks.append(concatenated_chunks)
+                    self.accumulated_chunks_timings.append(time.time() - start_time)
+                    # Gets new response state
+                    self.sequence_number += 1
+                    new_response_state = self._create_chunk_response(concatenated_chunks)
+                    # Writes the new response state to the response
+                    await response.write(new_response_state.encode('utf-8'))
+
+            if chunk is not None and isinstance(chunk, StreamPromptingSynapse):
+                self.finish_reason = "completed"
+                self.sequence_number += 1
+                # Assuming the last chunk holds the last value yielded which should be a synapse with the completion filled
+                synapse = chunk
+                final_state = self._create_chunk_response(synapse.completion)
+                await response.write(final_state.encode('utf-8'))
+
+        except Exception as e:
+            bt.logging.error(
+                f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
+            )
+            response.set_status(500, reason="Internal error")
+            error_response = self._create_error_response(str(e))
+            await response.write(error_response.encode('utf-8'))
+        finally:
+            await response.write_eof()  # Ensure to close the response properly
+            return response
+
+    def _create_chunk_response(self, chunk: str) -> StreamChunk:
+        """
+        Creates a StreamChunk object with the current state.
+
+        :param chunk: List of strings representing the current chunk.
+        :return: StreamChunk object.
+        """
+        return StreamChunk(
+            delta=chunk,
+            finish_reason=self.finish_reason,
+            accumulated_chunks=self.accumulated_chunks,
+            accumulated_chunks_timings=self.accumulated_chunks_timings,
+            timestamp=self._current_timestamp(),
+            sequence_number=self.sequence_number,
+            selected_uid=self.selected_uid
+        )
+
+    def _create_error_response(self, error_message: str) -> StreamError:
+        """
+        Creates a StreamError object.
+
+        :param error_message: Error message to include in the StreamError.
+        :return: StreamError object.
+        """
+        return StreamError(
+            error=error_message,
+            timestamp=self._current_timestamp(),
+            sequence_number=self.sequence_number
+        )
+
+    def _current_timestamp(self) -> str:
+        """
+        Returns the current timestamp in ISO format.
+
+        :return: Current timestamp as a string.
+        """
+        return datetime.utcnow().isoformat()
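For reference, a quick illustration (not part of the commit) of what a single chunk looks like on the wire, assuming the validators package is importable and pydantic is installed; the field values are invented for the example:

from validators.streamer import StreamChunk

chunk = StreamChunk(
    delta="hello ",
    finish_reason=None,
    accumulated_chunks=["hello "],
    accumulated_chunks_timings=[0.31],
    timestamp="2024-01-01T00:00:00",
    sequence_number=1,
    selected_uid=0,
)

# StreamChunk.encode serializes with json.dumps(self.dict(), indent=4), so the
# response body carries pretty-printed JSON documents written back to back.
print(chunk.encode("utf-8").decode("utf-8"))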
