Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
import hashlib
import logging
import os
import posixpath
import re
import sqlite3
import sys
import tempfile
import time
from datetime import datetime

import paramiko
# --- Logging -------------------------------------------------------------
# Log to both a file and the console so this long-running loop can be
# watched live and audited afterwards.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("data_transfer.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# --- Configuration -------------------------------------------------------
REMOTE_HOST = "86.38.238.117"
REMOTE_USER = "root"  # Replace with your actual username
REMOTE_KEY_PATH = "~/.ssh/id_rsa"  # Replace with path to your SSH key
REMOTE_DATA_DIR = "/root/neuralos-demo/train_dataset_encoded_online"  # Remote source directory
LOCAL_DATA_DIR = "./train_dataset_encoded_online"  # Local destination
DB_FILE = "transfer_state.db"  # SQLite file tracking what was already copied
POLL_INTERVAL = 300  # Seconds between remote checks (5 minutes)

# Make sure the local destination exists before the first transfer.
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
def initialize_database(db_path=None):
    """Create the SQLite schema used to track transferred files.

    Two tables are created if they don't already exist:
      * ``transferred_files`` — one row per successfully copied file, keyed
        by filename, storing the remote size/mtime observed at transfer time
        plus an MD5 checksum of the local copy.
      * ``transfer_state`` — a simple key/value store for bookkeeping (e.g.
        timestamp of the last CSV/PKL transfer).

    Args:
        db_path: Optional path to the SQLite database file. Defaults to the
            module-level ``DB_FILE`` when None (backward compatible).
    """
    conn = sqlite3.connect(db_path if db_path is not None else DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS transferred_files (
                id INTEGER PRIMARY KEY,
                filename TEXT UNIQUE,
                remote_size INTEGER,
                remote_mtime REAL,
                transfer_time TIMESTAMP,
                checksum TEXT
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS transfer_state (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        ''')
        conn.commit()
    finally:
        # Close even if schema creation fails so the handle is never leaked.
        conn.close()
def is_file_transferred(filename, remote_size, remote_mtime, db_path=None):
    """Return True if *filename* was already transferred unchanged.

    A file counts as transferred only when a prior record exists with the
    exact same remote size AND mtime — any remote modification invalidates
    the record and triggers a re-transfer.

    Args:
        filename: Base name of the remote file.
        remote_size: Size reported by the remote stat.
        remote_mtime: Modification time reported by the remote stat.
        db_path: Optional database path; defaults to ``DB_FILE`` when None.
    """
    conn = sqlite3.connect(db_path if db_path is not None else DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT 1 FROM transferred_files WHERE filename = ? AND remote_size = ? AND remote_mtime = ?",
            (filename, remote_size, remote_mtime)
        )
        return cursor.fetchone() is not None
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
def mark_file_transferred(filename, remote_size, remote_mtime, checksum, db_path=None):
    """Record a successful transfer of *filename* in the tracking database.

    Uses ``INSERT OR REPLACE`` keyed on the UNIQUE filename column, so
    re-transferring a changed file overwrites the previous record.

    Args:
        filename: Base name of the transferred file.
        remote_size: Remote size at transfer time.
        remote_mtime: Remote mtime at transfer time.
        checksum: MD5 hex digest of the local copy.
        db_path: Optional database path; defaults to ``DB_FILE`` when None.
    """
    conn = sqlite3.connect(db_path if db_path is not None else DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """INSERT OR REPLACE INTO transferred_files
            (filename, remote_size, remote_mtime, transfer_time, checksum)
            VALUES (?, ?, ?, ?, ?)""",
            (filename, remote_size, remote_mtime, datetime.now().isoformat(), checksum)
        )
        conn.commit()
    finally:
        # Always release the connection, even if the insert raises.
        conn.close()
def update_transfer_state(key, value, db_path=None):
    """Upsert a key/value pair in the ``transfer_state`` bookkeeping table.

    Args:
        key: State key (e.g. ``"last_pkl_transfer"``).
        value: State value stored as text.
        db_path: Optional database path; defaults to ``DB_FILE`` when None.
    """
    conn = sqlite3.connect(db_path if db_path is not None else DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT OR REPLACE INTO transfer_state (key, value) VALUES (?, ?)",
            (key, value)
        )
        conn.commit()
    finally:
        # Always release the connection, even if the upsert raises.
        conn.close()
def get_transfer_state(key, db_path=None):
    """Return the stored value for *key*, or None when no record exists.

    Args:
        key: State key to look up in the ``transfer_state`` table.
        db_path: Optional database path; defaults to ``DB_FILE`` when None.
    """
    conn = sqlite3.connect(db_path if db_path is not None else DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT value FROM transfer_state WHERE key = ?", (key,))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
def calculate_checksum(file_path):
    """Return the MD5 hex digest of the file at *file_path*.

    Reads in 4 KiB chunks so arbitrarily large files can be hashed with
    constant memory. MD5 is used purely as an integrity fingerprint here,
    not for security.
    """
    digest = hashlib.md5()
    with open(file_path, 'rb') as stream:
        while True:
            block = stream.read(4096)
            if not block:  # EOF
                break
            digest.update(block)
    return digest.hexdigest()
def create_ssh_client():
    """Open and return a paramiko SSH connection to the configured server.

    Authenticates with the RSA private key at ``REMOTE_KEY_PATH``. Logs and
    re-raises any connection/authentication failure.
    """
    ssh = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy blindly trusts unknown host keys; consider
    # loading a known_hosts file if MITM resistance matters here.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    key_file = os.path.expanduser(REMOTE_KEY_PATH)
    try:
        private_key = paramiko.RSAKey.from_private_key_file(key_file)
        ssh.connect(
            hostname=REMOTE_HOST,
            username=REMOTE_USER,
            pkey=private_key,
        )
        logger.info(f"Successfully connected to {REMOTE_USER}@{REMOTE_HOST}")
        return ssh
    except Exception as e:
        logger.error(f"Failed to connect to {REMOTE_HOST}: {str(e)}")
        raise
def safe_transfer_file(sftp, remote_path, local_path):
    """
    Download *remote_path* to *local_path* via SFTP without ever exposing a
    partially-written file at the final destination.

    The download goes to ``local_path + ".tmp"`` first and is only moved
    into place once complete, so readers of *local_path* never see a
    truncated file.

    Returns:
        The MD5 hex checksum of the transferred file.

    Raises:
        Whatever the SFTP layer raises; the temporary file is removed before
        the exception propagates.
    """
    temp_file = local_path + ".tmp"
    try:
        sftp.get(remote_path, temp_file)
        checksum = calculate_checksum(temp_file)
        # os.replace (unlike os.rename) atomically overwrites an existing
        # destination on every platform — os.rename raises on Windows when
        # re-downloading a file that already exists locally.
        os.replace(temp_file, local_path)
        logger.info(f"Successfully transferred {remote_path} to {local_path}")
        return checksum
    except Exception as e:
        logger.error(f"Error transferring {remote_path}: {str(e)}")
        # Clean up the partial download, if any.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise
def is_file_stable(sftp, remote_path, wait_time=30):
    """
    Heuristically decide whether *remote_path* has stopped growing.

    Stats the file twice, *wait_time* seconds apart; an unchanged size is
    taken to mean the remote writer has finished.

    Returns:
        (is_stable, stat) where *stat* is the second stat result, or
        (False, None) when the check itself failed (e.g. file vanished).
    """
    try:
        size_before = sftp.stat(remote_path).st_size
        time.sleep(wait_time)
        stat_after = sftp.stat(remote_path)
        size_after = stat_after.st_size
        if size_before != size_after:
            logger.info(f"File {remote_path} is still being written to (size changed from {size_before} to {size_after})")
            return False, stat_after
        return True, stat_after
    except Exception as e:
        logger.error(f"Error checking if {remote_path} is stable: {str(e)}")
        return False, None
def transfer_tar_files(sftp):
    """Transfer every ``record_*.tar`` that isn't already copied locally.

    For each candidate: skip if the tracking DB shows it unchanged, skip if
    it is still being written, otherwise download it atomically and record
    it. Per-file failures are logged and skipped so one bad file doesn't
    stall the rest.

    NOTE(review): run_transfer_cycle() implements its own snapshot-based TAR
    transfer and nothing in this file calls this helper — presumably an
    older standalone path; confirm before removing.

    Returns:
        Number of files newly transferred (0 on listing failure).
    """
    transferred_count = 0
    try:
        tar_pattern = re.compile(r'record_.*\.tar$')
        remote_files = sftp.listdir(REMOTE_DATA_DIR)
        tar_files = [f for f in remote_files if tar_pattern.match(f)]
        logger.info(f"Found {len(tar_files)} TAR files on remote server")
        for tar_file in tar_files:
            # Remote SFTP paths are always POSIX; os.path.join would emit
            # backslashes on Windows, so join with posixpath explicitly.
            remote_path = posixpath.join(REMOTE_DATA_DIR, tar_file)
            local_path = os.path.join(LOCAL_DATA_DIR, tar_file)
            try:
                stat = sftp.stat(remote_path)
            except FileNotFoundError:
                logger.warning(f"File {remote_path} disappeared, skipping")
                continue
            # Unchanged size+mtime means we already hold this version.
            if is_file_transferred(tar_file, stat.st_size, stat.st_mtime):
                logger.debug(f"Skipping already transferred file: {tar_file}")
                continue
            is_stable, updated_stat = is_file_stable(sftp, remote_path)
            if not is_stable:
                logger.info(f"Skipping unstable file: {tar_file}")
                continue
            try:
                checksum = safe_transfer_file(sftp, remote_path, local_path)
                mark_file_transferred(tar_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                transferred_count += 1
            except Exception as e:
                logger.error(f"Failed to transfer {tar_file}: {str(e)}")
                continue
        logger.info(f"Transferred {transferred_count} new TAR files")
        return transferred_count
    except Exception as e:
        logger.error(f"Error in transfer_tar_files: {str(e)}")
        return 0
def transfer_pkl_file(sftp):
    """Transfer the image/action mapping PKL file if new or changed.

    Returns:
        True when the local copy is up to date (freshly transferred or
        already current), False on failure, missing file, or instability.
    """
    pkl_file = "image_action_mapping_with_key_states.pkl"
    # Remote SFTP paths are POSIX; os.path.join would break on Windows.
    remote_path = posixpath.join(REMOTE_DATA_DIR, pkl_file)
    local_path = os.path.join(LOCAL_DATA_DIR, pkl_file)
    try:
        try:
            stat = sftp.stat(remote_path)
        except FileNotFoundError:
            logger.warning(f"PKL file {remote_path} not found")
            return False
        # Unchanged size+mtime means we already hold this exact version.
        if is_file_transferred(pkl_file, stat.st_size, stat.st_mtime):
            logger.debug(f"Skipping already transferred PKL file (unchanged)")
            return True
        is_stable, updated_stat = is_file_stable(sftp, remote_path)
        if not is_stable:
            logger.info(f"PKL file is still being written to, skipping")
            return False
        checksum = safe_transfer_file(sftp, remote_path, local_path)
        mark_file_transferred(pkl_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
        update_transfer_state("last_pkl_transfer", datetime.now().isoformat())
        logger.info(f"Successfully transferred PKL file")
        return True
    except Exception as e:
        logger.error(f"Error transferring PKL file: {str(e)}")
        return False
def transfer_csv_file(sftp):
    """Transfer the target-frames CSV file if new or changed.

    Returns:
        True when the local copy is up to date (freshly transferred or
        already current), False on failure, missing file, or instability.
    """
    csv_file = "train_dataset.target_frames.csv"
    # Remote SFTP paths are POSIX; os.path.join would break on Windows.
    remote_path = posixpath.join(REMOTE_DATA_DIR, csv_file)
    local_path = os.path.join(LOCAL_DATA_DIR, csv_file)
    try:
        try:
            stat = sftp.stat(remote_path)
        except FileNotFoundError:
            logger.warning(f"CSV file {remote_path} not found")
            return False
        # Unchanged size+mtime means we already hold this exact version.
        if is_file_transferred(csv_file, stat.st_size, stat.st_mtime):
            logger.debug(f"Skipping already transferred CSV file (unchanged)")
            return True
        is_stable, updated_stat = is_file_stable(sftp, remote_path)
        if not is_stable:
            logger.info(f"CSV file is still being written to, skipping")
            return False
        checksum = safe_transfer_file(sftp, remote_path, local_path)
        mark_file_transferred(csv_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
        update_transfer_state("last_csv_transfer", datetime.now().isoformat())
        logger.info(f"Successfully transferred CSV file")
        return True
    except Exception as e:
        logger.error(f"Error transferring CSV file: {str(e)}")
        return False
def transfer_padding_file(sftp):
    """Transfer ``padding.npy`` if new or changed.

    Returns:
        True when the local copy is up to date (freshly transferred or
        already current), False on failure, missing file, or instability.
    """
    padding_file = "padding.npy"
    # Remote SFTP paths are POSIX; os.path.join would break on Windows.
    remote_path = posixpath.join(REMOTE_DATA_DIR, padding_file)
    local_path = os.path.join(LOCAL_DATA_DIR, padding_file)
    try:
        try:
            stat = sftp.stat(remote_path)
        except FileNotFoundError:
            logger.warning(f"Padding file {remote_path} not found")
            return False
        # Unchanged size+mtime means we already hold this exact version.
        if is_file_transferred(padding_file, stat.st_size, stat.st_mtime):
            logger.debug(f"Skipping already transferred padding file (unchanged)")
            return True
        is_stable, updated_stat = is_file_stable(sftp, remote_path)
        if not is_stable:
            logger.info(f"Padding file is still being written to, skipping")
            return False
        checksum = safe_transfer_file(sftp, remote_path, local_path)
        mark_file_transferred(padding_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
        update_transfer_state("last_padding_transfer", datetime.now().isoformat())
        logger.info(f"Successfully transferred padding.npy file")
        return True
    except Exception as e:
        logger.error(f"Error transferring padding file: {str(e)}")
        return False
def run_transfer_cycle():
    """Run one complete transfer pass against the remote server.

    Steps:
      0. Snapshot the remote directory (name -> size/mtime/path) so every
         decision in this cycle is made against one consistent view.
      1. Transfer ``padding.npy`` if new/changed.
      2. Transfer any new ``record_*.tar`` files from the snapshot.
      3. Transfer the PKL mapping file if new/changed.
      4. Transfer the CSV file, but only when the PKL step succeeded, so the
         two metadata files never get out of sync locally.

    Returns:
        True when anything new was transferred or the metadata files were
        already current; False on failure or no progress.
    """
    client = None
    try:
        client = create_ssh_client()
        sftp = client.open_sftp()

        # Step 0: consistent snapshot of the remote directory.
        logger.info("Taking snapshot of remote directory state")
        remote_files = {}
        for filename in sftp.listdir(REMOTE_DATA_DIR):
            try:
                # Remote SFTP paths are POSIX — join with posixpath so this
                # also works when the script runs on Windows.
                file_path = posixpath.join(REMOTE_DATA_DIR, filename)
                stat = sftp.stat(file_path)
                remote_files[filename] = {
                    'size': stat.st_size,
                    'mtime': stat.st_mtime,
                    'path': file_path
                }
            except Exception as e:
                # Log the actual filename (the previous message said
                # "(unknown)" even though the name is in scope).
                logger.warning(f"Could not stat file {filename}: {str(e)}")
        logger.info(f"Found {len(remote_files)} files in remote directory snapshot")

        # Step 1: padding.npy, if needed.
        if "padding.npy" in remote_files:
            file_info = remote_files["padding.npy"]
            if not is_file_transferred("padding.npy", file_info['size'], file_info['mtime']):
                is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
                if is_stable:
                    try:
                        local_path = os.path.join(LOCAL_DATA_DIR, "padding.npy")
                        checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                        mark_file_transferred("padding.npy", updated_stat.st_size, updated_stat.st_mtime, checksum)
                        logger.info("Successfully transferred padding.npy file")
                    except Exception as e:
                        # Don't let one bad file abort the whole cycle — the
                        # TAR loop below already isolates per-file failures.
                        logger.error(f"Failed to transfer padding.npy: {str(e)}")
                else:
                    logger.warning("Padding file is still being written, skipping")
        else:
            logger.warning("padding.npy not found in remote directory")

        # Step 2: TAR files from the snapshot.
        tar_pattern = re.compile(r'record_.*\.tar$')
        tar_files = {name: info for name, info in remote_files.items() if tar_pattern.match(name)}
        logger.info(f"Found {len(tar_files)} TAR files in snapshot")
        tar_count = 0
        for tar_file, file_info in tar_files.items():
            # Unchanged size+mtime means we already hold this version.
            if is_file_transferred(tar_file, file_info['size'], file_info['mtime']):
                logger.debug(f"Skipping already transferred file: {tar_file}")
                continue
            is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
            if not is_stable:
                logger.info(f"Skipping unstable file: {tar_file}")
                continue
            try:
                local_path = os.path.join(LOCAL_DATA_DIR, tar_file)
                checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                mark_file_transferred(tar_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                tar_count += 1
            except Exception as e:
                logger.error(f"Failed to transfer {tar_file}: {str(e)}")
        logger.info(f"Transferred {tar_count} new TAR files from snapshot")

        # Step 3: PKL file from the snapshot.
        pkl_file = "image_action_mapping_with_key_states.pkl"
        if pkl_file in remote_files:
            file_info = remote_files[pkl_file]
            if not is_file_transferred(pkl_file, file_info['size'], file_info['mtime']):
                is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
                if is_stable:
                    local_path = os.path.join(LOCAL_DATA_DIR, pkl_file)
                    checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                    mark_file_transferred(pkl_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                    update_transfer_state("last_pkl_transfer", datetime.now().isoformat())
                    logger.info("Successfully transferred PKL file from snapshot")
                    pkl_success = True
                else:
                    logger.warning("PKL file is still being written, skipping")
                    pkl_success = False
            else:
                logger.debug("PKL file unchanged, skipping")
                pkl_success = True
        else:
            logger.warning("PKL file not found in snapshot")
            pkl_success = False

        # Step 4: CSV file from the snapshot — only when the PKL succeeded,
        # to keep the CSV and PKL metadata consistent with each other.
        csv_file = "train_dataset.target_frames.csv"
        if pkl_success and csv_file in remote_files:
            file_info = remote_files[csv_file]
            if not is_file_transferred(csv_file, file_info['size'], file_info['mtime']):
                is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
                if is_stable:
                    local_path = os.path.join(LOCAL_DATA_DIR, csv_file)
                    checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                    mark_file_transferred(csv_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                    update_transfer_state("last_csv_transfer", datetime.now().isoformat())
                    logger.info("Successfully transferred CSV file from snapshot")
                    csv_success = True
                else:
                    logger.warning("CSV file is still being written, skipping")
                    csv_success = False
            else:
                logger.debug("CSV file unchanged, skipping")
                csv_success = True
        else:
            if not pkl_success:
                logger.warning("Skipping CSV transfer because PKL transfer failed")
            else:
                logger.warning("CSV file not found in snapshot")
            csv_success = False

        return tar_count > 0 or pkl_success or csv_success
    except Exception as e:
        logger.error(f"Error in transfer cycle: {str(e)}")
        return False
    finally:
        # Closing the client also tears down the SFTP channel.
        if client:
            client.close()
def main():
    """Entry point: set up state tracking, then poll the remote forever.

    Runs one transfer cycle every ``POLL_INTERVAL`` seconds until the user
    interrupts (Ctrl-C); any other unhandled exception is logged and
    re-raised.
    """
    logger.info("Starting data transfer script")
    initialize_database()
    try:
        while True:
            logger.info("Starting new transfer cycle")
            if run_transfer_cycle():
                logger.info("Transfer cycle completed with new files transferred")
            else:
                logger.info("Transfer cycle completed with no changes")
            logger.info(f"Sleeping for {POLL_INTERVAL} seconds before next check")
            time.sleep(POLL_INTERVAL)
    except KeyboardInterrupt:
        logger.info("Script terminated by user")
    except Exception as e:
        logger.error(f"Unhandled exception: {str(e)}")
        raise


if __name__ == "__main__":
    main()