|
from typing import Iterable, Literal |
|
import sys |
|
|
|
|
|
def flatten(iterable: Iterable, depth = sys.maxsize, return_type: Literal['list', 'generator'] = 'list') -> list | Iterable: |
|
""" |
|
Flatten a nested iterable up to a specified depth. |
|
|
|
Args: |
|
iterable (iterable): The iterable to be expanded. |
|
depth (int, optional): The depth to which the iterable should be expanded. |
|
Defaults to 1. |
|
return_type (Literal['list', 'generator'], optional): The type of the return value. |
|
Defaults to 'list'. |
|
Yields: |
|
The expanded elements. |
|
""" |
|
|
|
def expand(item, current_depth=0): |
|
if current_depth == depth: |
|
yield item |
|
elif isinstance(item, (list, tuple, set)): |
|
for sub_item in item: |
|
yield from expand(sub_item, current_depth + 1) |
|
else: |
|
yield item |
|
|
|
def generator(): |
|
for item in iterable: |
|
yield from expand(item) |
|
|
|
if return_type == 'list': |
|
return list(generator()) |
|
return generator() |