Spaces:
Runtime error
Runtime error
create a pipeline class with basic test coverage
Browse files- pipeline_test.py → pipeline.py +67 -79
- tests/test_pipeline.py +83 -0
pipeline_test.py → pipeline.py
RENAMED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
import random
|
| 3 |
-
import time
|
| 4 |
-
|
| 5 |
|
| 6 |
class Job:
|
| 7 |
-
def __init__(self,
|
| 8 |
-
self.
|
| 9 |
self.data = data
|
| 10 |
|
| 11 |
|
|
@@ -31,6 +28,11 @@ class Node:
|
|
| 31 |
self._jobs_dequeued += 1
|
| 32 |
if self.sequential_node == False:
|
| 33 |
await self.process_job(job)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
else:
|
| 35 |
# ensure that jobs are processed in order
|
| 36 |
self.buffer[job.id] = job
|
|
@@ -38,79 +40,65 @@ class Node:
|
|
| 38 |
job = self.buffer.pop(self.next_i)
|
| 39 |
await self.process_job(job)
|
| 40 |
self.next_i += 1
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
async def process_job(self, job: Job):
|
| 48 |
-
raise NotImplementedError
|
| 49 |
|
| 50 |
-
|
| 51 |
-
class Node1(Node):
|
| 52 |
async def process_job(self, job: Job):
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
start_time = time.time()
|
| 109 |
-
|
| 110 |
-
try:
|
| 111 |
-
asyncio.run(main())
|
| 112 |
-
except KeyboardInterrupt:
|
| 113 |
-
print("Pipeline interrupted by user")
|
| 114 |
-
|
| 115 |
-
end_time = time.time()
|
| 116 |
-
print(f"Pipeline processed in {end_time - start_time} seconds.")
|
|
|
|
| 1 |
import asyncio
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
class Job:
    """A unit of work that is passed from node to node.

    Attributes:
        data: The payload supplied by the caller; nodes may modify it.
        id: Sequence number of the job. It is ``None`` until something
            assigns it (the pipeline assigns it when the job is enqueued).
    """

    def __init__(self, data):
        self._id = None
        self.data = data

    @property
    def id(self):
        """Return the job's sequence number, or ``None`` if not yet assigned."""
        return self._id

    @id.setter
    def id(self, value):
        # Code elsewhere reads and writes ``job.id`` while ``__init__`` only
        # created ``_id``; backing the public name with ``_id`` keeps both
        # spellings in sync and makes ``job.id`` readable before assignment.
        self._id = value
|
| 7 |
|
| 8 |
|
|
|
|
| 28 |
self._jobs_dequeued += 1
|
| 29 |
if self.sequential_node == False:
|
| 30 |
await self.process_job(job)
|
| 31 |
+
if self.output_queue is not None:
|
| 32 |
+
await self.output_queue.put(job)
|
| 33 |
+
if self.job_sync is not None:
|
| 34 |
+
self.job_sync.append(job)
|
| 35 |
+
self._jobs_processed += 1
|
| 36 |
else:
|
| 37 |
# ensure that jobs are processed in order
|
| 38 |
self.buffer[job.id] = job
|
|
|
|
| 40 |
job = self.buffer.pop(self.next_i)
|
| 41 |
await self.process_job(job)
|
| 42 |
self.next_i += 1
|
| 43 |
+
if self.output_queue is not None:
|
| 44 |
+
await self.output_queue.put(job)
|
| 45 |
+
if self.job_sync is not None:
|
| 46 |
+
self.job_sync.append(job)
|
| 47 |
+
self._jobs_processed += 1
|
|
|
|
|
|
|
|
|
|
| 48 |
|
|
|
|
|
|
|
| 49 |
async def process_job(self, job: Job):
|
| 50 |
+
raise NotImplementedError()
|
| 51 |
+
|
| 52 |
+
class Pipeline:
    """Connect node workers through queues and run each worker as an asyncio task.

    Attributes:
        input_queues: Input queue of every node added, in insertion order.
        root_queue: Input queue of the first node added; ``enqueue_job``
            puts new jobs here.
        nodes: Names of the node types added, without duplicates.
        node_workers: Maps a node name to the list of worker objects created
            for it.
        tasks: The asyncio tasks running each worker's ``run()`` coroutine.
    """

    def __init__(self):
        self.input_queues = []
        self.root_queue = None
        self.nodes = []
        self.node_workers = {}
        self.tasks = []
        self._job_id = 0

    async def add_node(self, node: "type[Node]", num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        """Create ``num_workers`` workers of ``node`` and start them as tasks.

        Must be awaited from a running event loop because the workers are
        started with ``asyncio.create_task``.

        Args:
            node: Callable (normally a node class) invoked as
                ``node(worker_id, input_queue, output_queue, job_sync,
                sequential_node)``; the result must provide a ``run()``
                coroutine.
            num_workers: Number of workers to create; must be at least 1.
            input_queue: Queue the workers are given as their input. Required.
            output_queue: Queue the workers are given as their output, or
                ``None``.
            job_sync: Optional list handed to the workers; only allowed for
                sequential nodes.
            sequential_node: Whether the node is sequential; such a node may
                only have a single worker.

        Raises:
            ValueError: If any of the constraints above is violated. Nothing
                is registered or started in that case.
        """
        # input_queue must not be None
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync nodes must be sequential nodes
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # sequential nodes may only have one worker
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # a node without workers would never take anything off its queue
        if num_workers < 1:
            raise ValueError('num_workers must be at least 1')
        # the output queue must not be the very queue the node reads from
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')

        # ``node`` is normally a class, so ``node.__class__.__name__`` would be
        # 'type' for every node; prefer the callable's own name.
        node_name = getattr(node, '__name__', node.__class__.__name__)
        if node_name not in self.nodes:
            self.nodes.append(node_name)

        # the first node added defines the queue that enqueue_job feeds
        if not self.input_queues:
            self.root_queue = input_queue

        self.input_queues.append(input_queue)

        # keep every worker, not just the last one created for this name
        workers = self.node_workers.setdefault(node_name, [])
        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            workers.append(node_worker)
            self.tasks.append(asyncio.create_task(node_worker.run()))

    async def enqueue_job(self, job: "Job"):
        """Assign the next sequential id to ``job`` and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet, so there is no root
                queue to put the job on.
        """
        if self.root_queue is None:
            raise RuntimeError('cannot enqueue a job before a node has been added')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel every worker task and wait until all of them have finished."""
        for task in self.tasks:
            task.cancel()
        # return_exceptions=True keeps the CancelledError of each task from
        # propagating out of close().
        await asyncio.gather(*self.tasks, return_exceptions=True)
|
| 103 |
+
|
| 104 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_pipeline.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import random
|
| 3 |
+
import time
|
| 4 |
+
import unittest
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from pipeline import Pipeline, Node, Job
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Node1(Node):
    """Test node for stage 1: appends a marker naming the stage and worker."""

    async def process_job(self, job: Job):
        # Build the marker first, then attach it to the job's payload.
        marker = f' (processed by node 1, worker {self.worker_id})'
        job.data += marker
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Node2(Node):
    """Test node for stage 2: waits a short random time, then marks the job."""

    async def process_job(self, job: Job):
        # Sleep between 0.08 and 0.12 seconds so workers finish at
        # different times.
        await asyncio.sleep(0.08 + 0.04 * random.random())
        marker = f' (processed by node 2, worker {self.worker_id})'
        job.data += marker
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Node3(Node):
    """Test node for stage 3: marks the job and prints its id and payload."""

    async def process_job(self, job: Job):
        marker = f' (processed by node 3, worker {self.worker_id})'
        job.data += marker
        # Show the id next to the accumulated payload.
        print(f'{job.id} - {job.data}')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TestPipeline(unittest.TestCase):
    """Tests for Pipeline argument validation and end-to-end job flow."""

    def setUp(self):
        # A fresh pipeline and result list per test, so no state is shared
        # between test methods.
        self.pipeline = Pipeline()
        self.job_sync = []

    async def _test_pipeline_edge_cases(self):
        """Check that add_node rejects invalid queue arguments."""
        # a node must have an input queue
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, None, None)
        # the output queue must not be the same as the input queue
        node1_queue = asyncio.Queue()
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)

    async def _test_pipeline(self, num_jobs, timeout=30.0):
        """Push ``num_jobs`` jobs through a three-stage pipeline.

        Waits until ``self.job_sync`` holds ``num_jobs`` entries and fails
        the test if that does not happen within ``timeout`` seconds, so a
        lost job cannot hang the test run. The pipeline is closed whether
        or not the wait succeeds.
        """
        node1_queue = asyncio.Queue()
        node2_queue = asyncio.Queue()
        node3_queue = asyncio.Queue()
        await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
        await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
        await self.pipeline.add_node(Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True)
        try:
            for _ in range(num_jobs):
                await self.pipeline.enqueue_job(Job(""))
            deadline = time.monotonic() + timeout
            while len(self.job_sync) < num_jobs:
                if time.monotonic() > deadline:
                    self.fail(f'only {len(self.job_sync)} of {num_jobs} jobs finished within {timeout} seconds')
                await asyncio.sleep(0.1)
        finally:
            # always cancel the worker tasks, even when the wait failed
            await self.pipeline.close()

    def test_pipeline_edge_cases(self):
        asyncio.run(self._test_pipeline_edge_cases())

    # NOTE(review): this ordering test was committed commented out; the
    # reason is not recorded here, so it is left disabled.
    # def test_pipeline_keeps_order(self):
    #     self.pipeline = Pipeline()
    #     self.job_sync = []
    #     num_jobs = 100
    #     start_time = time.time()
    #     asyncio.run(self._test_pipeline(num_jobs))
    #     end_time = time.time()
    #     print(f"Pipeline processed in {end_time - start_time} seconds.")
    #     self.assertEqual(len(self.job_sync), num_jobs)
    #     for i, job in enumerate(self.job_sync):
    #         self.assertEqual(i, job.id)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Run the unittest runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
    # test = TestPipeline()
    # test.setUp()
    # test.test_pipeline()
|