Spaces:
Running
Running
import os | |
import re | |
import math | |
import time | |
import asyncio | |
import pyrogram | |
import traceback | |
import functools | |
import logging | |
import inspect | |
from hashlib import md5 | |
from datetime import datetime | |
from typing import Union, BinaryIO, List, Optional, Callable, Dict | |
from pyrogram import Client, utils, raw | |
from pyrogram.session import Session, Auth | |
from pyrogram.errors import AuthBytesInvalid | |
from pyrogram.file_id import FileId, FileType, ThumbnailSource | |
from pyrogram.types import Message | |
from pyrogram import StopTransmission | |
from pyrogram import Client, utils, raw | |
from pyrogram.session import Session, Auth | |
from pyrogram.errors import AuthBytesInvalid | |
from pyrogram.file_id import FileId, FileType, ThumbnailSource | |
from typing import Union, BinaryIO, List, Optional, Callable | |
from pyrogram.types import Message | |
from pyrogram import raw | |
from pyrogram import types | |
from pyrogram import utils as pgutils | |
from pyrogram import StopTransmission, enums | |
from pyrogram.errors import FilePartMissing | |
from pyrogram.file_id import FileType | |
from pyrogram.enums import ParseMode, ChatType | |
from pyrogram import filters, Client | |
from pyrogram.errors import FloodWait | |
from pyrogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton | |
from pyrogram.enums.parse_mode import ParseMode | |
#-----------------------Local Imports----------------------------------# | |
from FileStream import utils, StartTime, __version__ | |
from FileStream.bot import FileStream, MULTI_CLIENTS, WORK_LOADS | |
from FileStream.utils.FileProcessors.bot_utils import is_user_banned, is_user_exist, is_user_joined, gen_link, is_channel_banned, is_channel_exist, is_user_authorized | |
from FileStream.Database import Database | |
from FileStream.utils.FileProcessors.file_properties import get_file_ids, get_file_info | |
from FileStream.Tools.tool import TimeFormatter | |
from FileStream.config import Telegram | |
from .file_properties import get_file_ids | |
class TGFileController:
    """Resolve, re-upload and re-send Telegram media on behalf of one client.

    Resolved ``FileId`` objects are cached per database id in
    ``cached_file_ids``; a background task empties that cache every
    ``clean_timer`` seconds.
    """

    def __init__(self, client: Client):
        """Bind the controller to ``client`` and schedule cache cleanup.

        Must be called while an asyncio event loop is running, because it
        schedules :meth:`clean_cache` with ``asyncio.create_task``.
        """
        self.clean_timer = 30 * 60  # seconds between cache purges
        self.client: Client = client
        self.cached_file_ids: Dict[str, FileId] = {}
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        self._clean_cache_task = asyncio.create_task(self.clean_cache())

    async def clean_cache(self) -> None:
        """Empty ``cached_file_ids`` every ``clean_timer`` seconds, forever."""
        while True:
            await asyncio.sleep(self.clean_timer)
            print("** Caches Cleared :", self.cached_file_ids)
            self.cached_file_ids.clear()
            print("Cleaned the cache")
            logging.debug("Cleaned the cache")

    async def get_me(self):
        """Return the username of the account behind ``self.client``."""
        # get_me() is a coroutine: await it before reading attributes.
        me = await self.client.get_me()
        return me.username

    async def get_file_properties(self, db_id: str, MULTI_CLIENTS) -> FileId:
        """Return the ``FileId`` for ``db_id``, resolving and caching it on a miss.

        ``MULTI_CLIENTS`` is forwarded to :meth:`generate_file_properties`
        (which currently does not use it).
        """
        if db_id not in self.cached_file_ids:
            logging.debug("Before Calling generate_file_properties")
            await self.generate_file_properties(db_id, MULTI_CLIENTS)
            logging.debug(f"Cached file properties for file with ID {db_id}")
        return self.cached_file_ids[db_id]

    async def generate_file_properties(self, db_id: str, MULTI_CLIENTS) -> FileId:
        """Resolve ``db_id`` via ``get_file_ids``, store it in the cache and return it.

        ``MULTI_CLIENTS`` is accepted for interface compatibility but unused.
        """
        logging.debug("Before calling get_file_ids")
        # NOTE(review): the Message *class* is passed as the third argument;
        # confirm this matches get_file_ids' expected signature.
        file_id = await get_file_ids(self.client, db_id, Message)
        logging.debug(f"Generated file ID and Unique ID for file with ID {db_id}")
        self.cached_file_ids[db_id] = file_id
        logging.debug(f"Cached media file with ID {db_id}")
        return self.cached_file_ids[db_id]

    async def generate_media_session(self, client: Client, file_id: FileId) -> Session:
        """Return a started media ``Session`` for the DC that stores ``file_id``.

        A session already registered in ``client.media_sessions`` is reused.
        For a DC other than the client's home DC, a new auth key is created and
        an exported authorization is imported (up to 6 attempts). For the home
        DC, the stored auth key is reused. New sessions are registered in
        ``client.media_sessions``.

        Raises:
            AuthBytesInvalid: if every authorization import attempt fails.
        """
        media_session = client.media_sessions.get(file_id.dc_id, None)
        if media_session is not None:
            logging.debug(f"Using cached media session for DC {file_id.dc_id}")
            return media_session

        if file_id.dc_id != await client.storage.dc_id():
            media_session = Session(
                client,
                file_id.dc_id,
                await Auth(client, file_id.dc_id, await
                           client.storage.test_mode()).create(),
                await client.storage.test_mode(),
                is_media=True,
            )
            await media_session.start()
            for _ in range(6):
                exported_auth = await client.invoke(
                    raw.functions.auth.ExportAuthorization(dc_id=file_id.dc_id))
                try:
                    await media_session.invoke(
                        raw.functions.auth.ImportAuthorization(
                            id=exported_auth.id, bytes=exported_auth.bytes))
                    break
                except AuthBytesInvalid:
                    logging.debug(
                        f"Invalid authorization bytes for DC {file_id.dc_id}")
                    continue
            else:
                # No attempt succeeded: release the session before failing.
                await media_session.stop()
                raise AuthBytesInvalid
        else:
            media_session = Session(
                client,
                file_id.dc_id,
                await client.storage.auth_key(),
                await client.storage.test_mode(),
                is_media=True,
            )
            await media_session.start()
        logging.debug(f"Created media session for DC {file_id.dc_id}")
        client.media_sessions[file_id.dc_id] = media_session
        return media_session

    @staticmethod
    async def get_location(file_id: FileId) -> Union[
            raw.types.InputPhotoFileLocation,
            raw.types.InputDocumentFileLocation,
            raw.types.InputPeerPhotoFileLocation,
    ]:
        """Build the raw input file location used by ``upload.GetFile``.

        Chat photos map to ``InputPeerPhotoFileLocation``, photos to
        ``InputPhotoFileLocation``, and everything else to
        ``InputDocumentFileLocation``.
        """
        file_type = file_id.file_type
        if file_type == FileType.CHAT_PHOTO:
            if file_id.chat_id > 0:
                peer = raw.types.InputPeerUser(user_id=file_id.chat_id,
                                               access_hash=file_id.chat_access_hash)
            elif file_id.chat_access_hash == 0:
                peer = raw.types.InputPeerChat(chat_id=-file_id.chat_id)
            else:
                # pgutils is pyrogram.utils; plain ``utils`` is rebound to
                # FileStream.utils by the local imports.
                peer = raw.types.InputPeerChannel(
                    channel_id=pgutils.get_channel_id(file_id.chat_id),
                    access_hash=file_id.chat_access_hash,
                )
            return raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                volume_id=file_id.volume_id,
                local_id=file_id.local_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG,
            )
        if file_type == FileType.PHOTO:
            return raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        return raw.types.InputDocumentFileLocation(
            id=file_id.media_id,
            access_hash=file_id.access_hash,
            file_reference=file_id.file_reference,
            thumb_size=file_id.thumbnail_size,
        )

    async def send(self,
                   media: Union["types.InputMedia", "types.InputMediaPhoto",
                                "types.InputMediaVideo", "types.InputMediaAudio",
                                "types.InputMediaAnimation",
                                "types.InputMediaDocument",
                                "types.InputPhoneContact"], caption: str,
                   reply_to_msg_id: str, chat_id: Union[int, str]):
        """Send ``media`` with ``caption`` to ``chat_id`` via raw ``messages.SendMedia``.

        ``reply_to_msg_id`` is included only when truthy.

        Returns:
            The parsed ``types.Message`` for the first new/scheduled-message
            update in the response. Returns ``None`` if no such update exists,
            or if sending failed. On failure the error is posted to
            ``Telegram.ULOG_CHANNEL`` and the traceback is printed.
        """
        client = self.client
        try:
            send_kwargs = dict(
                peer=await client.resolve_peer(chat_id),
                media=media,
                message=caption,
                random_id=client.rnd_id(),
            )
            if reply_to_msg_id:
                send_kwargs["reply_to_msg_id"] = reply_to_msg_id
            r = await client.invoke(
                raw.functions.messages.SendMedia(**send_kwargs))
        except Exception as e:
            await client.send_message(chat_id=Telegram.ULOG_CHANNEL,
                                      text=f"**#EʀʀᴏʀTʀᴀᴄᴋᴇʙᴀᴄᴋ:** `{e}`",
                                      disable_web_page_preview=True)
            print(
                f"Cᴀɴ'ᴛ Eᴅɪᴛ Bʀᴏᴀᴅᴄᴀsᴛ Mᴇssᴀɢᴇ!\nEʀʀᴏʀ: **Gɪᴠᴇ ᴍᴇ ᴇᴅɪᴛ ᴘᴇʀᴍɪssɪᴏɴ ɪɴ ᴜᴘᴅᴀᴛᴇs ᴀɴᴅ ʙɪɴ Cʜᴀɴɴᴇʟ!{traceback.format_exc()}**"
            )
            return None

        new_message_types = (raw.types.UpdateNewMessage,
                             raw.types.UpdateNewChannelMessage,
                             raw.types.UpdateNewScheduledMessage)
        for update in r.updates:
            if isinstance(update, new_message_types):
                users = {user.id: user for user in r.users}
                chats = {chat.id: chat for chat in r.chats}
                return await types.Message._parse(
                    client,
                    update.message,
                    users,
                    chats,
                    is_scheduled=isinstance(update, raw.types.UpdateNewScheduledMessage))
        return None

    async def upload_file(
            self,
            index: int,
            file_id: FileId,
            file_name: str,
            file_size: int,
            progress: Callable,
            progress_args: tuple = ()
    ) -> Union[raw.types.InputFile, raw.types.InputFileBig]:
        """Copy a Telegram-hosted file into a fresh upload and return its handle.

        Parts of ``file_id`` are fetched with ``upload.GetFile`` through the
        media session for its DC. They are re-uploaded with ``SaveFilePart`` or
        ``SaveBigFilePart`` on a new media session for the client's home DC,
        using 1 worker (4 for files over 10 MiB).

        Args:
            index: key into ``WORK_LOADS``; incremented while the copy runs.
            file_id: source media.
            file_name: name attached to the returned input file.
            file_size: total size in bytes (part count, limit, progress).
            progress: optional callback, sync or async, called after each part
                with ``(current_bytes, file_size, progress_args)``.
            progress_args: passed to ``progress`` as ONE extra argument (not
                unpacked). NOTE(review): Pyrogram itself unpacks these. This
                is kept as-is because existing callbacks may expect the tuple.

        Returns:
            ``raw.types.InputFileBig`` for files over 10 MiB, otherwise
            ``raw.types.InputFile`` carrying the hex MD5 of the uploaded bytes.

        Raises:
            ValueError: if ``file_size`` is 0 or above the account's limit.
        """
        client = self.client
        part_size = 512 * 1024  # size of each GetFile request and upload part

        if file_size == 0:
            raise ValueError("File size equals to 0 B")
        file_size_limit_mib = 4000 if client.me.is_premium else 2000
        if file_size > file_size_limit_mib * 1024 * 1024:
            raise ValueError(
                f"Can't upload files bigger than {file_size_limit_mib} MiB")

        file_total_parts = int(math.ceil(file_size / part_size))
        is_big = file_size > 10 * 1024 * 1024
        workers_count = 4 if is_big else 1
        new_file_id = client.rnd_id()
        md5_sum = None if is_big else md5()

        # Resolve the source before spawning workers or touching WORK_LOADS, so
        # a failure here cannot leave blocked tasks or a skewed counter behind.
        media_session = await self.generate_media_session(client, file_id)
        location = await self.get_location(file_id)

        session = Session(client,
                          await client.storage.dc_id(),
                          await client.storage.auth_key(),
                          await client.storage.test_mode(),
                          is_media=True)
        # maxsize=1: the downloader blocks until a worker takes the last part.
        queue = asyncio.Queue(1)

        async def worker(upload_session):
            """Invoke queued upload RPCs until a ``None`` sentinel arrives."""
            while True:
                data = await queue.get()
                if data is None:
                    return
                try:
                    await upload_session.invoke(data)
                except Exception as e:
                    logging.exception(e)

        workers = [
            client.loop.create_task(worker(session)) for _ in range(workers_count)
        ]
        WORK_LOADS[index] += 1
        logging.debug(f"Starting to yielding file with client {index}.")
        file_part = 0
        offset = 0
        try:
            await session.start()
            r = await media_session.invoke(
                raw.functions.upload.GetFile(location=location,
                                             offset=offset,
                                             limit=part_size), )
            if isinstance(r, raw.types.upload.File):
                while True:
                    chunk = r.bytes
                    if not chunk:
                        break
                    if is_big:
                        rpc = raw.functions.upload.SaveBigFilePart(
                            file_id=new_file_id,
                            file_part=file_part,
                            file_total_parts=file_total_parts,
                            bytes=chunk)
                    else:
                        md5_sum.update(chunk)
                        rpc = raw.functions.upload.SaveFilePart(file_id=new_file_id,
                                                                file_part=file_part,
                                                                bytes=chunk)
                    await queue.put(rpc)
                    file_part += 1
                    offset += part_size
                    # Report before the exit check so the final part is reported too.
                    if progress:
                        func = functools.partial(progress,
                                                 min(file_part * part_size, file_size),
                                                 file_size, progress_args)
                        if inspect.iscoroutinefunction(progress):
                            await func()
                        else:
                            await client.loop.run_in_executor(client.executor, func)
                    if file_part >= file_total_parts:
                        break
                    r = await media_session.invoke(
                        raw.functions.upload.GetFile(location=location,
                                                     offset=offset,
                                                     limit=part_size), )
        except (TimeoutError, AttributeError):
            # Deliberate best-effort (kept from the original): stop copying and
            # fall through to return the handle, but leave a trace in the logs.
            logging.exception(f"Upload interrupted after {file_part} parts.")
        finally:
            # One sentinel per worker so every worker leaves its loop.
            for _ in workers:
                await queue.put(None)
            await asyncio.gather(*workers)
            await session.stop()
            logging.debug(f"Finished yielding file with {file_part} parts.")
            WORK_LOADS[index] -= 1

        if is_big:
            return raw.types.InputFileBig(
                id=new_file_id,
                parts=file_total_parts,
                name=file_name,
            )
        # The TL field is a string: send the hex digest, not the hashlib object.
        return raw.types.InputFile(id=new_file_id,
                                   parts=file_total_parts,
                                   name=file_name,
                                   md5_checksum=md5_sum.hexdigest())