import os
import time
import math
import asyncio
import logging
import traceback
from aiohttp import web
from typing import AsyncGenerator, Dict, Union, Optional
from pyrogram.types import Message
from pyrogram import Client, raw
from pyrogram import utils as pyro_utils  # aliased: the `from FileStream import utils` below would otherwise shadow it
from pyrogram.session import Session, Auth
from pyrogram.errors import AuthBytesInvalid
from aiohttp.http_exceptions import BadStatusLine
from pyrogram.file_id import FileId, FileType, ThumbnailSource
#---------------------Local Upload---------------------#
from FileStream.config import Telegram
from .file_properties import get_file_ids
from FileStream.bot import req_client, FileStream
from FileStream import utils, StartTime, __version__
from FileStream.Tools import mime_identifier, Time_ISTKolNow
from FileStream.utils.FileProcessors.custom_ul import TeleUploader
from FileStream.Exceptions import FileNotFound, InvalidHash
from FileStream.bot import MULTI_CLIENTS, WORK_LOADS, ACTIVE_CLIENTS

class ByteStreamer:
    def __init__(self, client: Client):
        self.clean_timer = 30 * 60  # Cache cleanup interval: 30 minutes
        self.client: Client = client
        self.cached_file_ids: Dict[str, FileId] = {}  # Cache of file properties, keyed by db_id
        self.last_activity: float = asyncio.get_event_loop().time()  # Track last activity time for the client
        asyncio.create_task(self.clean_cache())  # Start the periodic cache cleanup task

    def update_last_activity(self):
        """Update the last activity time to the current time."""
        self.last_activity = asyncio.get_event_loop().time()

    def get_last_activity(self) -> float:
        """Get the last activity time of this client."""
        return self.last_activity

    async def get_file_properties(self, db_id: str, MULTI_CLIENTS) -> FileId:
        """
        Return the properties of the media in a specific message as a FileId object.
        If the properties are already cached, the cached result is returned;
        otherwise they are generated from the message and cached.
        """
        if db_id not in self.cached_file_ids:
            logging.debug("File properties not cached. Generating properties.")
            await self.generate_file_properties(db_id, MULTI_CLIENTS)  # Generate and cache the file properties
            logging.debug(f"Cached file properties for file with ID {db_id}")
        return self.cached_file_ids[db_id]
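
    # A minimal usage sketch (hypothetical caller, not part of this module;
    # `index` and `db_id` are illustrative names):
    #
    #   streamer = ByteStreamer(MULTI_CLIENTS[index])
    #   file_id = await streamer.get_file_properties(db_id, MULTI_CLIENTS)
    #   # `file_id` now exposes dc_id, media_id, access_hash, file_reference, ...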

    async def generate_file_properties(self, db_id: str, MULTI_CLIENTS) -> FileId:
        """
        Generate the properties of the media file in a specific message
        and return them as a FileId object.
        """
        logging.debug("Generating file properties.")
        file_id = await get_file_ids(self.client, db_id, Message)  # Fetch and build the file properties
        logging.debug(f"Generated file ID and unique ID for file with ID {db_id}")
        self.cached_file_ids[db_id] = file_id  # Cache the file properties
        logging.debug(f"Cached media file with ID {db_id}")
        return file_id

    async def generate_media_session(self, client: Client, file_id: FileId) -> Session:
        """
        Generate the media session for the DC that contains the media file.
        This is required for getting the bytes from Telegram's servers.
        """
        media_session = client.media_sessions.get(file_id.dc_id, None)
        if media_session is None:
            if file_id.dc_id != await client.storage.dc_id():
                # The file lives on another DC: create a fresh session for it
                media_session = Session(
                    client,
                    file_id.dc_id,
                    await Auth(client, file_id.dc_id, await client.storage.test_mode()).create(),
                    await client.storage.test_mode(),
                    is_media=True,
                )
                await media_session.start()

                # Export the current authorization and import it on the media DC
                for _ in range(6):
                    exported_auth = await client.invoke(
                        raw.functions.auth.ExportAuthorization(dc_id=file_id.dc_id)
                    )
                    try:
                        await media_session.invoke(
                            raw.functions.auth.ImportAuthorization(
                                id=exported_auth.id, bytes=exported_auth.bytes
                            )
                        )
                        break
                    except AuthBytesInvalid:
                        logging.debug(f"Invalid authorization bytes for DC {file_id.dc_id}")
                        continue
                else:
                    await media_session.stop()
                    raise AuthBytesInvalid
            else:
                # Already connected to the correct DC: reuse the stored auth key
                media_session = Session(
                    client,
                    file_id.dc_id,
                    await client.storage.auth_key(),
                    await client.storage.test_mode(),
                    is_media=True,
                )
                await media_session.start()
            logging.debug(f"Created media session for DC {file_id.dc_id}")
            client.media_sessions[file_id.dc_id] = media_session  # Cache the media session
        else:
            logging.debug(f"Using cached media session for DC {file_id.dc_id}")
        return media_session
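
    # Note on the DC hop above: MTProto auth keys are bound to a single DC, so
    # when the media lives somewhere other than the session's home DC, the
    # authorization has to be exported from the home DC and imported on the
    # media DC (the retry loop in generate_media_session) before upload.GetFile
    # calls are accepted there.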

    async def get_location(self, file_id: FileId) -> Union[
        raw.types.InputPhotoFileLocation,
        raw.types.InputDocumentFileLocation,
        raw.types.InputPeerPhotoFileLocation,
    ]:
        """
        Return the file location for the media file based on its type
        (chat photo, photo or document).
        """
        file_type = file_id.file_type

        if file_type == FileType.CHAT_PHOTO:
            # Chat photos: build the peer the photo belongs to
            if file_id.chat_id > 0:
                peer = raw.types.InputPeerUser(
                    user_id=file_id.chat_id, access_hash=file_id.chat_access_hash
                )
            else:
                peer = raw.types.InputPeerChannel(
                    channel_id=pyro_utils.get_channel_id(file_id.chat_id),
                    access_hash=file_id.chat_access_hash,
                )

            location = raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                volume_id=file_id.volume_id,
                local_id=file_id.local_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG,
            )
        elif file_type == FileType.PHOTO:
            # Regular photos
            location = raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        else:
            # Document files
            location = raw.types.InputDocumentFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        return location

    async def yield_file(
        self,
        file_id: FileId,
        index: int,
        offset: int,
        first_part_cut: int,
        last_part_cut: int,
        part_count: int,
        chunk_size: int,
    ) -> AsyncGenerator[bytes, None]:
        """
        Yield the file in chunks based on the requested range and chunk size.
        Streams the file from Telegram's servers, trimming the first and last
        chunks so that exactly the requested byte range is produced.
        """
        client = self.client
        WORK_LOADS[index] += 1  # Increase the workload for this client
        logging.debug(f"Starting to yield file with client {index}.")
        media_session = await self.generate_media_session(client, file_id)
        current_part = 1
        location = await self.get_location(file_id)
        try:
            # Fetch the first chunk
            r = await media_session.invoke(
                raw.functions.upload.GetFile(location=location, offset=offset, limit=chunk_size)
            )
            if isinstance(r, raw.types.upload.File):
                # Stream the file in chunks, trimming the first and last parts
                while True:
                    chunk = r.bytes
                    if not chunk:
                        break
                    elif part_count == 1:
                        yield chunk[first_part_cut:last_part_cut]
                    elif current_part == 1:
                        yield chunk[first_part_cut:]
                    elif current_part == part_count:
                        yield chunk[:last_part_cut]
                    else:
                        yield chunk

                    current_part += 1
                    offset += chunk_size

                    if current_part > part_count:
                        break

                    # Fetch the next chunk
                    r = await media_session.invoke(
                        raw.functions.upload.GetFile(location=location, offset=offset, limit=chunk_size)
                    )
        except (TimeoutError, AttributeError):
            pass
        except Exception as e:
            logging.info(f"Error in ByteStreamer while generating chunk: {e}")
        finally:
            logging.debug(f"Finished yielding file with {current_part} parts.")
            WORK_LOADS[index] -= 1  # Decrease the workload for this client

    async def clean_cache(self) -> None:
        """
        Periodically clear the cached file properties to reduce memory usage.
        Runs for the lifetime of the client, waking up every `clean_timer` seconds.
        """
        while True:
            await asyncio.sleep(self.clean_timer)  # Wait for the cleanup interval
            logging.info("*** Cleaning cached file IDs...")
            self.cached_file_ids.clear()  # Clear the cache
            logging.debug("Cache cleaned.")