#!/usr/bin/env python3
"""Periodically mirror training-dataset files from a remote host over SFTP.

Polls a remote directory, transfers new/changed TAR, PKL, CSV and NPY files,
and records completed transfers in a local SQLite database so each file
version (filename + size + mtime) is copied at most once.
"""
import os
import sys
import time
import logging
import paramiko
import hashlib
import tempfile
from datetime import datetime
import sqlite3
import re

# Configure logging: everything goes both to a log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("data_transfer.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Configuration
REMOTE_HOST = "neural-os.com"
REMOTE_USER = "root"  # Replace with your actual username
REMOTE_KEY_PATH = "~/.ssh/id_rsa"  # Replace with path to your SSH key
REMOTE_DATA_DIR = "/root/neuralos-demo/train_dataset_encoded_online"  # Replace with actual path
LOCAL_DATA_DIR = "./train_dataset_encoded_online"  # Local destination
DB_FILE = "transfer_state.db"
POLL_INTERVAL = 300  # Check for new files every 5 minutes

# Ensure local directories exist
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
def initialize_database():
    """Create the SQLite tables used to track transferred files.

    Idempotent (CREATE TABLE IF NOT EXISTS), so it is safe to call on
    every startup.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        # One row per successfully transferred file version.
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS transferred_files (
            id INTEGER PRIMARY KEY,
            filename TEXT UNIQUE,
            remote_size INTEGER,
            remote_mtime REAL,
            transfer_time TIMESTAMP,
            checksum TEXT
        )
        ''')
        # Free-form key/value state, e.g. timestamps of the last CSV/PKL transfer.
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS transfer_state (
            key TEXT PRIMARY KEY,
            value TEXT
        )
        ''')
        conn.commit()
    finally:
        # Close even if table creation fails, so the handle is never leaked.
        conn.close()
def is_file_transferred(filename, remote_size, remote_mtime):
    """Return True if this exact (filename, size, mtime) version was already transferred."""
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT 1 FROM transferred_files WHERE filename = ? AND remote_size = ? AND remote_mtime = ?",
            (filename, remote_size, remote_mtime)
        )
        return cursor.fetchone() is not None
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
def mark_file_transferred(filename, remote_size, remote_mtime, checksum):
    """Record a successful transfer (upserts by filename, replacing stale versions)."""
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """INSERT OR REPLACE INTO transferred_files
               (filename, remote_size, remote_mtime, transfer_time, checksum)
               VALUES (?, ?, ?, ?, ?)""",
            (filename, remote_size, remote_mtime, datetime.now().isoformat(), checksum)
        )
        conn.commit()
    finally:
        # Always release the connection, even if the insert raises.
        conn.close()
def update_transfer_state(key, value):
    """Upsert a key/value pair into the transfer_state table."""
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT OR REPLACE INTO transfer_state (key, value) VALUES (?, ?)",
            (key, value)
        )
        conn.commit()
    finally:
        # Always release the connection, even if the upsert raises.
        conn.close()
def get_transfer_state(key):
    """Return the stored value for *key*, or None if it has never been set."""
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT value FROM transfer_state WHERE key = ?", (key,))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
def calculate_checksum(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    MD5 is used here as a fast integrity fingerprint, not for security.
    Reads in 64 KiB chunks so arbitrarily large files never need to fit
    in memory at once.
    """
    md5 = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(65536), b''):
            md5.update(chunk)
    return md5.hexdigest()
def create_ssh_client():
    """Connect to REMOTE_HOST using key-based auth and return the paramiko client.

    Logs and re-raises any connection failure.
    NOTE(review): RSAKey.from_private_key_file only loads RSA keys; an
    Ed25519/ECDSA key at REMOTE_KEY_PATH would fail here — confirm key type.
    """
    client = paramiko.SSHClient()
    # Auto-accept unknown host keys. Convenient for internal tooling, but it
    # disables host-key verification (MITM risk) — acceptable here by design.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    # Expand "~" so the key path works regardless of how it is configured.
    key_path = os.path.expanduser(REMOTE_KEY_PATH)
    try:
        key = paramiko.RSAKey.from_private_key_file(key_path)
        client.connect(
            hostname=REMOTE_HOST,
            username=REMOTE_USER,
            pkey=key
        )
        logger.info(f"Successfully connected to {REMOTE_USER}@{REMOTE_HOST}")
        return client
    except Exception as e:
        logger.error(f"Failed to connect to {REMOTE_HOST}: {str(e)}")
        raise
def safe_transfer_file(sftp, remote_path, local_path):
    """
    Transfer a file safely: download to a sibling ".tmp" file, checksum it,
    then atomically move it into place so readers never see a partial file.

    Returns the MD5 checksum of the transferred file.
    Raises on any transfer error (after removing the temp file).
    """
    # Temp file lives next to the destination so the final move stays on the
    # same filesystem (rename/replace is only atomic within one filesystem).
    temp_file = local_path + ".tmp"
    try:
        # Transfer to temporary file
        sftp.get(remote_path, temp_file)
        # Calculate checksum
        checksum = calculate_checksum(temp_file)
        # os.replace (not os.rename) so an existing destination is
        # overwritten atomically on all platforms, including Windows.
        os.replace(temp_file, local_path)
        logger.info(f"Successfully transferred {remote_path} to {local_path}")
        return checksum
    except Exception as e:
        logger.error(f"Error transferring {remote_path}: {str(e)}")
        # Clean up temp file if it exists
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise
def is_file_stable(sftp, remote_path, wait_time=30):
    """
    Heuristically check that a remote file is no longer being written to,
    by comparing its size before and after *wait_time* seconds.

    Returns a tuple (is_stable, stat): stat is the most recent
    SFTPAttributes for the file, or None when it could not be stat'ed.
    NOTE(review): a writer that pauses longer than wait_time can still be
    misdetected as stable — this is best-effort only.
    """
    try:
        # Get initial stats
        initial_stat = sftp.stat(remote_path)
        initial_size = initial_stat.st_size
        # Wait a bit
        time.sleep(wait_time)
        # Get updated stats
        updated_stat = sftp.stat(remote_path)
        updated_size = updated_stat.st_size
        # File is stable if size hasn't changed
        is_stable = initial_size == updated_size
        if not is_stable:
            logger.info(f"File {remote_path} is still being written to (size changed from {initial_size} to {updated_size})")
        return is_stable, updated_stat
    except Exception as e:
        logger.error(f"Error checking if {remote_path} is stable: {str(e)}")
        return False, None
def transfer_tar_files(sftp):
    """Transfer all record_*.tar files that haven't been transferred yet.

    Skips files that are unchanged (same size + mtime as a prior transfer)
    or still being written. Per-file errors are logged and skipped so one
    bad file does not abort the batch. Returns the number of files
    transferred (0 on listing error).
    """
    transferred_count = 0
    try:
        # List all tar files
        tar_pattern = re.compile(r'record_.*\.tar$')
        remote_files = sftp.listdir(REMOTE_DATA_DIR)
        tar_files = [f for f in remote_files if tar_pattern.match(f)]
        logger.info(f"Found {len(tar_files)} TAR files on remote server")
        for tar_file in tar_files:
            # Remote paths are POSIX: join with '/' rather than os.path.join
            # so this also works when the script runs on Windows.
            remote_path = f"{REMOTE_DATA_DIR}/{tar_file}"
            local_path = os.path.join(LOCAL_DATA_DIR, tar_file)
            # Get file stats
            try:
                stat = sftp.stat(remote_path)
            except FileNotFoundError:
                logger.warning(f"File {remote_path} disappeared, skipping")
                continue
            # Skip if already transferred with same size and mtime
            if is_file_transferred(tar_file, stat.st_size, stat.st_mtime):
                logger.debug(f"Skipping already transferred file: {tar_file}")
                continue
            # Check if file is stable (not being written to)
            is_stable, updated_stat = is_file_stable(sftp, remote_path)
            if not is_stable:
                logger.info(f"Skipping unstable file: {tar_file}")
                continue
            # Transfer the file
            try:
                checksum = safe_transfer_file(sftp, remote_path, local_path)
                mark_file_transferred(tar_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                transferred_count += 1
            except Exception as e:
                logger.error(f"Failed to transfer {tar_file}: {str(e)}")
                continue
        logger.info(f"Transferred {transferred_count} new TAR files")
        return transferred_count
    except Exception as e:
        logger.error(f"Error in transfer_tar_files: {str(e)}")
        return 0
def transfer_pkl_file(sftp):
    """Transfer the PKL mapping file if it is new or has changed.

    Returns True when the local copy is up to date (freshly transferred or
    already current), False when missing, unstable, or on error.
    """
    pkl_file = "image_action_mapping_with_key_states.pkl"
    # Remote paths are POSIX: join with '/' so this works from Windows too.
    remote_path = f"{REMOTE_DATA_DIR}/{pkl_file}"
    local_path = os.path.join(LOCAL_DATA_DIR, pkl_file)
    try:
        # Check if file exists
        try:
            stat = sftp.stat(remote_path)
        except FileNotFoundError:
            logger.warning(f"PKL file {remote_path} not found")
            return False
        # Skip if already transferred with same size and mtime
        if is_file_transferred(pkl_file, stat.st_size, stat.st_mtime):
            logger.debug(f"Skipping already transferred PKL file (unchanged)")
            return True
        # Check if file is stable
        is_stable, updated_stat = is_file_stable(sftp, remote_path)
        if not is_stable:
            logger.info(f"PKL file is still being written to, skipping")
            return False
        # Transfer the file
        checksum = safe_transfer_file(sftp, remote_path, local_path)
        mark_file_transferred(pkl_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
        # Record when the PKL was last refreshed.
        update_transfer_state("last_pkl_transfer", datetime.now().isoformat())
        logger.info(f"Successfully transferred PKL file")
        return True
    except Exception as e:
        logger.error(f"Error transferring PKL file: {str(e)}")
        return False
def transfer_csv_file(sftp):
    """Transfer the CSV index file if it is new or has changed.

    Returns True when the local copy is up to date (freshly transferred or
    already current), False when missing, unstable, or on error.
    """
    csv_file = "train_dataset.target_frames.csv"
    # Remote paths are POSIX: join with '/' so this works from Windows too.
    remote_path = f"{REMOTE_DATA_DIR}/{csv_file}"
    local_path = os.path.join(LOCAL_DATA_DIR, csv_file)
    try:
        # Check if file exists
        try:
            stat = sftp.stat(remote_path)
        except FileNotFoundError:
            logger.warning(f"CSV file {remote_path} not found")
            return False
        # Skip if already transferred with same size and mtime
        if is_file_transferred(csv_file, stat.st_size, stat.st_mtime):
            logger.debug(f"Skipping already transferred CSV file (unchanged)")
            return True
        # Check if file is stable
        is_stable, updated_stat = is_file_stable(sftp, remote_path)
        if not is_stable:
            logger.info(f"CSV file is still being written to, skipping")
            return False
        # Transfer the file
        checksum = safe_transfer_file(sftp, remote_path, local_path)
        mark_file_transferred(csv_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
        # Record when the CSV was last refreshed.
        update_transfer_state("last_csv_transfer", datetime.now().isoformat())
        logger.info(f"Successfully transferred CSV file")
        return True
    except Exception as e:
        logger.error(f"Error transferring CSV file: {str(e)}")
        return False
def transfer_padding_file(sftp):
    """Transfer the padding.npy file if it is new or has changed.

    Returns True when the local copy is up to date (freshly transferred or
    already current), False when missing, unstable, or on error.
    """
    padding_file = "padding.npy"
    # Remote paths are POSIX: join with '/' so this works from Windows too.
    remote_path = f"{REMOTE_DATA_DIR}/{padding_file}"
    local_path = os.path.join(LOCAL_DATA_DIR, padding_file)
    try:
        # Check if file exists
        try:
            stat = sftp.stat(remote_path)
        except FileNotFoundError:
            logger.warning(f"Padding file {remote_path} not found")
            return False
        # Skip if already transferred with same size and mtime
        if is_file_transferred(padding_file, stat.st_size, stat.st_mtime):
            logger.debug(f"Skipping already transferred padding file (unchanged)")
            return True
        # Check if file is stable
        is_stable, updated_stat = is_file_stable(sftp, remote_path)
        if not is_stable:
            logger.info(f"Padding file is still being written to, skipping")
            return False
        # Transfer the file
        checksum = safe_transfer_file(sftp, remote_path, local_path)
        mark_file_transferred(padding_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
        # Record when the padding file was last refreshed.
        update_transfer_state("last_padding_transfer", datetime.now().isoformat())
        logger.info(f"Successfully transferred padding.npy file")
        return True
    except Exception as e:
        logger.error(f"Error transferring padding file: {str(e)}")
        return False
def run_transfer_cycle():
    """Run one complete transfer cycle against a point-in-time snapshot.

    Order of operations: snapshot the remote directory, then transfer
    padding.npy, new TAR files, the PKL mapping, and finally the CSV index.
    The CSV is only transferred after the PKL is current, so the local CSV
    never references mapping entries that are missing locally.

    Returns True if anything new was transferred (or PKL/CSV were already
    current), False on failure or when nothing progressed.
    """
    client = None
    try:
        # Connect to the remote server
        client = create_ssh_client()
        sftp = client.open_sftp()
        # Step 0: Get a snapshot of all files with their timestamps.
        # This creates a consistent view of the remote directory at this point in time.
        logger.info("Taking snapshot of remote directory state")
        remote_files = {}
        for filename in sftp.listdir(REMOTE_DATA_DIR):
            # Remote paths are POSIX: join with '/' rather than os.path.join
            # so this also works when the script runs on Windows.
            file_path = f"{REMOTE_DATA_DIR}/{filename}"
            try:
                stat = sftp.stat(file_path)
                remote_files[filename] = {
                    'size': stat.st_size,
                    'mtime': stat.st_mtime,
                    'path': file_path
                }
            except Exception as e:
                # Include the actual filename so the warning is actionable
                # (previously logged a literal "(unknown)").
                logger.warning(f"Could not stat file {filename}: {str(e)}")
        logger.info(f"Found {len(remote_files)} files in remote directory snapshot")
        # Step 1: Transfer padding.npy file if needed
        if "padding.npy" in remote_files:
            file_info = remote_files["padding.npy"]
            if not is_file_transferred("padding.npy", file_info['size'], file_info['mtime']):
                # Check stability
                is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
                if is_stable:
                    local_path = os.path.join(LOCAL_DATA_DIR, "padding.npy")
                    checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                    mark_file_transferred("padding.npy", updated_stat.st_size, updated_stat.st_mtime, checksum)
                    logger.info("Successfully transferred padding.npy file")
                else:
                    logger.warning("Padding file is still being written, skipping")
        else:
            logger.warning("padding.npy not found in remote directory")
        # Step 2: Transfer TAR files from the snapshot
        tar_pattern = re.compile(r'record_.*\.tar$')
        tar_files = {name: info for name, info in remote_files.items() if tar_pattern.match(name)}
        logger.info(f"Found {len(tar_files)} TAR files in snapshot")
        tar_count = 0
        for tar_file, file_info in tar_files.items():
            # Skip if already transferred with same size and mtime
            if is_file_transferred(tar_file, file_info['size'], file_info['mtime']):
                logger.debug(f"Skipping already transferred file: {tar_file}")
                continue
            # Check if file is stable
            is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
            if not is_stable:
                logger.info(f"Skipping unstable file: {tar_file}")
                continue
            # Per-file failures are logged and skipped so one bad file
            # does not abort the rest of the batch.
            try:
                local_path = os.path.join(LOCAL_DATA_DIR, tar_file)
                checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                mark_file_transferred(tar_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                tar_count += 1
            except Exception as e:
                logger.error(f"Failed to transfer {tar_file}: {str(e)}")
        logger.info(f"Transferred {tar_count} new TAR files from snapshot")
        # Step 3: Transfer PKL file from the snapshot
        pkl_file = "image_action_mapping_with_key_states.pkl"
        if pkl_file in remote_files:
            file_info = remote_files[pkl_file]
            # Only transfer if needed
            if not is_file_transferred(pkl_file, file_info['size'], file_info['mtime']):
                is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
                if is_stable:
                    local_path = os.path.join(LOCAL_DATA_DIR, pkl_file)
                    checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                    mark_file_transferred(pkl_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                    update_transfer_state("last_pkl_transfer", datetime.now().isoformat())
                    logger.info("Successfully transferred PKL file from snapshot")
                    pkl_success = True
                else:
                    logger.warning("PKL file is still being written, skipping")
                    pkl_success = False
            else:
                logger.debug("PKL file unchanged, skipping")
                pkl_success = True
        else:
            logger.warning("PKL file not found in snapshot")
            pkl_success = False
        # Step 4: Transfer CSV file from the snapshot (only if PKL succeeded),
        # so the CSV never points at mapping entries we do not have locally.
        csv_file = "train_dataset.target_frames.csv"
        if pkl_success and csv_file in remote_files:
            file_info = remote_files[csv_file]
            # Only transfer if needed
            if not is_file_transferred(csv_file, file_info['size'], file_info['mtime']):
                is_stable, updated_stat = is_file_stable(sftp, file_info['path'])
                if is_stable:
                    local_path = os.path.join(LOCAL_DATA_DIR, csv_file)
                    checksum = safe_transfer_file(sftp, file_info['path'], local_path)
                    mark_file_transferred(csv_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
                    update_transfer_state("last_csv_transfer", datetime.now().isoformat())
                    logger.info("Successfully transferred CSV file from snapshot")
                    csv_success = True
                else:
                    logger.warning("CSV file is still being written, skipping")
                    csv_success = False
            else:
                logger.debug("CSV file unchanged, skipping")
                csv_success = True
        else:
            if not pkl_success:
                logger.warning("Skipping CSV transfer because PKL transfer failed")
            else:
                logger.warning("CSV file not found in snapshot")
            csv_success = False
        return tar_count > 0 or pkl_success or csv_success
    except Exception as e:
        logger.error(f"Error in transfer cycle: {str(e)}")
        return False
    finally:
        # Always close the SSH connection, even after a failure.
        if client:
            client.close()
def main():
    """Run the transfer loop forever, polling every POLL_INTERVAL seconds.

    Stops cleanly on Ctrl-C; any other unhandled exception is logged and
    re-raised so the process exits with a failure status.
    """
    logger.info("Starting data transfer script")
    # Initialize the database (idempotent).
    initialize_database()
    try:
        while True:
            logger.info("Starting new transfer cycle")
            changes = run_transfer_cycle()
            if changes:
                logger.info("Transfer cycle completed with new files transferred")
            else:
                logger.info("Transfer cycle completed with no changes")
            logger.info(f"Sleeping for {POLL_INTERVAL} seconds before next check")
            time.sleep(POLL_INTERVAL)
    except KeyboardInterrupt:
        logger.info("Script terminated by user")
    except Exception as e:
        logger.error(f"Unhandled exception: {str(e)}")
        raise
if __name__ == "__main__":
    main()