import asyncio
import logging
from typing import AsyncGenerator, Dict, Union

from pyrogram.types import Message
from pyrogram import Client, utils, raw
from pyrogram.session import Session, Auth
from pyrogram.errors import AuthBytesInvalid
from pyrogram.file_id import FileId, FileType, ThumbnailSource

# -------------------- Local Imports -------------------- #
from .file_properties import get_file_ids
from FileStream.bot import WORK_LOADS


class ByteStreamer:
    def __init__(self, client: Client):
        self.clean_timer = 30 * 60  # Cache cleanup interval: 30 minutes
        self.client: Client = client
        self.cached_file_ids: Dict[str, FileId] = {}  # Cache of file properties, keyed by db_id
        asyncio.create_task(self.clean_cache())  # Start the periodic cache cleanup task

    async def get_file_properties(self, db_id: str, MULTI_CLIENTS) -> FileId:
        """
        Returns the properties of the media file of a specific message as a FileId object.
        If the properties are cached, the cached result is returned; otherwise they are
        generated from the message and cached.
        """
        if db_id not in self.cached_file_ids:
            logging.debug("File properties not cached. Generating properties.")
            await self.generate_file_properties(db_id, MULTI_CLIENTS)  # Generate and cache the file properties
            logging.debug(f"Cached file properties for file with ID {db_id}")
        return self.cached_file_ids[db_id]

    async def generate_file_properties(self, db_id: str, MULTI_CLIENTS) -> FileId:
        """
        Generates the properties of the media file attached to a specific message
        and returns them as a FileId object.
        """
        logging.debug("Generating file properties.")
        file_id = await get_file_ids(self.client, db_id, Message)  # Resolve the file properties
        logging.debug(f"Generated file ID and unique ID for file with ID {db_id}")
        self.cached_file_ids[db_id] = file_id  # Cache the file properties
        logging.debug(f"Cached media file with ID {db_id}")
        return file_id
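    # A minimal usage sketch (an assumption, not part of this module): a stream
    # handler would typically resolve the cached FileId first and then stream it,
    # e.g.
    #
    #     streamer = ByteStreamer(client)
    #     file_id = await streamer.get_file_properties(db_id, MULTI_CLIENTS)
    #     async for chunk in streamer.yield_file(file_id, index, offset,
    #                                            first_part_cut, last_part_cut,
    #                                            part_count, chunk_size):
    #         ...  # write the chunk to the HTTP response
    #
    # The names `index` (client pool slot) and the range parameters are derived
    # by the caller; see the note above yield_file() below.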
""" media_session = client.media_sessions.get(file_id.dc_id, None) if media_session is None: if file_id.dc_id != await client.storage.dc_id(): # Create a new media session if one doesn't exist for this DC ID media_session = Session( client, file_id.dc_id, await Auth(client, file_id.dc_id, await client.storage.test_mode()).create(), await client.storage.test_mode(), is_media=True, ) await media_session.start() # Attempt to import authorization from Telegram's servers for _ in range(6): exported_auth = await client.invoke( raw.functions.auth.ExportAuthorization(dc_id=file_id.dc_id)) try: # Import the authorization bytes for the DC await media_session.invoke( raw.functions.auth.ImportAuthorization( id=exported_auth.id, bytes=exported_auth.bytes)) break except AuthBytesInvalid: logging.debug(f"Invalid authorization bytes for DC {file_id.dc_id}") continue else: await media_session.stop() raise AuthBytesInvalid else: # Reuse the stored auth key if we're already connected to the correct DC media_session = Session( client, file_id.dc_id, await client.storage.auth_key(), await client.storage.test_mode(), is_media=True, ) await media_session.start() logging.debug(f"Created media session for DC {file_id.dc_id}") client.media_sessions[file_id.dc_id] = media_session # Cache the media session else: logging.debug(f"Using cached media session for DC {file_id.dc_id}") return media_session @staticmethod async def get_location(file_id: FileId) -> Union[ raw.types.InputPhotoFileLocation, raw.types.InputDocumentFileLocation, raw.types.InputPeerPhotoFileLocation, ]: """ Returns the file location for the media file based on its type (Photo or Document). """ file_type = file_id.file_type if file_type == FileType.CHAT_PHOTO: # Handle the case for chat photos if file_id.chat_id > 0: peer = raw.types.InputPeerUser(user_id=file_id.chat_id, access_hash=file_id.chat_access_hash) else: peer = raw.types.InputPeerChannel( channel_id=utils.get_channel_id(file_id.chat_id), access_hash=file_id.chat_access_hash, ) location = raw.types.InputPeerPhotoFileLocation( peer=peer, volume_id=file_id.volume_id, local_id=file_id.local_id, big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG, ) elif file_type == FileType.PHOTO: # Handle regular photos location = raw.types.InputPhotoFileLocation( id=file_id.media_id, access_hash=file_id.access_hash, file_reference=file_id.file_reference, thumb_size=file_id.thumbnail_size, ) else: # Handle document files location = raw.types.InputDocumentFileLocation( id=file_id.media_id, access_hash=file_id.access_hash, file_reference=file_id.file_reference, thumb_size=file_id.thumbnail_size, ) return location async def yield_file( self, file_id: FileId, index: int, offset: int, first_part_cut: int, last_part_cut: int, part_count: int, chunk_size: int, ) -> Union[str, None]: """ Yields the file in chunks based on the specified range and chunk size. This method streams the file from Telegram's server, breaking it into smaller parts. 
""" client = self.client WORK_LOADS[index] += 1 # Increase the workload for this client logging.debug(f"Starting to yield file with client {index}.") media_session = await self.generate_media_session(client, file_id) current_part = 1 location = await self.get_location(file_id) try: # Fetch the file chunks r = await media_session.invoke( raw.functions.upload.GetFile(location=location, offset=offset, limit=chunk_size), ) if isinstance(r, raw.types.upload.File): # Stream the file in chunks while True: chunk = r.bytes if not chunk: break elif part_count == 1: yield chunk[first_part_cut:last_part_cut] elif current_part == 1: yield chunk[first_part_cut:] elif current_part == part_count: yield chunk[:last_part_cut] else: yield chunk current_part += 1 offset += chunk_size if current_part > part_count: break r = await media_session.invoke( raw.functions.upload.GetFile(location=location, offset=offset, limit=chunk_size), ) except (TimeoutError, AttributeError): pass finally: logging.debug(f"Finished yielding file with {current_part} parts.") WORK_LOADS[index] -= 1 # Decrease the workload for this client async def clean_cache(self) -> None: """ Function to clean the cache to reduce memory usage. This method will be called periodically to clear the cached file properties. """ while True: await asyncio.sleep(self.clean_timer) # Wait for the cleanup interval logging.debug("Cleaning cached file properties...") self.cached_file_ids.clear() # Clear the cache logging.debug("Cache cleaned.")