da03 committed on
Commit e552755 · 1 Parent(s): a92ddb8
Files changed (1)
  1. sync_train_dataset.py +401 -0
sync_train_dataset.py ADDED
@@ -0,0 +1,401 @@
+ #!/usr/bin/env python3
+ import os
+ import sys
+ import time
+ import logging
+ import paramiko
+ import hashlib
+ import tempfile
+ from datetime import datetime
+ import sqlite3
+ import re
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler("data_transfer.log"),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ REMOTE_HOST = "86.38.238.117"
+ REMOTE_USER = "root"  # Replace with your actual username
+ REMOTE_KEY_PATH = "~/.ssh/id_rsa"  # Replace with path to your SSH key
+ REMOTE_DATA_DIR = "~/neuralos-demo/train_dataset_encoded_online"  # Replace with actual path
+ LOCAL_DATA_DIR = "./train_dataset_encoded_online"  # Local destination
+ DB_FILE = "transfer_state.db"
+ POLL_INTERVAL = 300  # Check for new files every 5 minutes
+
+ # Ensure local directories exist
+ os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
+
+
+ def initialize_database():
+     """Create and initialize the SQLite database to track transferred files."""
+     conn = sqlite3.connect(DB_FILE)
+     cursor = conn.cursor()
+
+     # Create tables if they don't exist
+     cursor.execute('''
+         CREATE TABLE IF NOT EXISTS transferred_files (
+             id INTEGER PRIMARY KEY,
+             filename TEXT UNIQUE,
+             remote_size INTEGER,
+             remote_mtime REAL,
+             transfer_time TIMESTAMP,
+             checksum TEXT
+         )
+     ''')
+
+     # Table for tracking last successful CSV/PKL transfer
+     cursor.execute('''
+         CREATE TABLE IF NOT EXISTS transfer_state (
+             key TEXT PRIMARY KEY,
+             value TEXT
+         )
+     ''')
+
+     conn.commit()
+     conn.close()
+
+
+ def is_file_transferred(filename, remote_size, remote_mtime):
+     """Check if a file has already been transferred with the same size and mtime."""
+     conn = sqlite3.connect(DB_FILE)
+     cursor = conn.cursor()
+
+     cursor.execute(
+         "SELECT 1 FROM transferred_files WHERE filename = ? AND remote_size = ? AND remote_mtime = ?",
+         (filename, remote_size, remote_mtime)
+     )
+     result = cursor.fetchone() is not None
+
+     conn.close()
+     return result
+
+
+ def mark_file_transferred(filename, remote_size, remote_mtime, checksum):
+     """Mark a file as successfully transferred."""
+     conn = sqlite3.connect(DB_FILE)
+     cursor = conn.cursor()
+
+     cursor.execute(
+         """INSERT OR REPLACE INTO transferred_files
+            (filename, remote_size, remote_mtime, transfer_time, checksum)
+            VALUES (?, ?, ?, ?, ?)""",
+         (filename, remote_size, remote_mtime, datetime.now().isoformat(), checksum)
+     )
+
+     conn.commit()
+     conn.close()
+
+
+ def update_transfer_state(key, value):
+     """Update the transfer state for a key."""
+     conn = sqlite3.connect(DB_FILE)
+     cursor = conn.cursor()
+
+     cursor.execute(
+         "INSERT OR REPLACE INTO transfer_state (key, value) VALUES (?, ?)",
+         (key, value)
+     )
+
+     conn.commit()
+     conn.close()
+
+
+ def get_transfer_state(key):
+     """Get the transfer state for a key."""
+     conn = sqlite3.connect(DB_FILE)
+     cursor = conn.cursor()
+
+     cursor.execute("SELECT value FROM transfer_state WHERE key = ?", (key,))
+     result = cursor.fetchone()
+
+     conn.close()
+     return result[0] if result else None
+
+
+ def calculate_checksum(file_path):
+     """Calculate MD5 checksum of a file."""
+     md5 = hashlib.md5()
+     with open(file_path, 'rb') as f:
+         for chunk in iter(lambda: f.read(4096), b''):
+             md5.update(chunk)
+     return md5.hexdigest()
+
+
+ def create_ssh_client():
+     """Create and return an SSH client connected to the remote server."""
+     client = paramiko.SSHClient()
+     client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+     # Expand the key path
+     key_path = os.path.expanduser(REMOTE_KEY_PATH)
+
+     try:
+         key = paramiko.RSAKey.from_private_key_file(key_path)
+         client.connect(
+             hostname=REMOTE_HOST,
+             username=REMOTE_USER,
+             pkey=key
+         )
+         logger.info(f"Successfully connected to {REMOTE_USER}@{REMOTE_HOST}")
+         return client
+     except Exception as e:
+         logger.error(f"Failed to connect to {REMOTE_HOST}: {str(e)}")
+         raise
+
+
+ def safe_transfer_file(sftp, remote_path, local_path):
+     """
+     Transfer a file safely using a temporary file and rename.
+     Returns the checksum of the transferred file.
+     """
+     # Create a temporary file for download
+     temp_file = local_path + ".tmp"
+
+     try:
+         # Transfer to temporary file
+         sftp.get(remote_path, temp_file)
+
+         # Calculate checksum
+         checksum = calculate_checksum(temp_file)
+
+         # Rename to final destination
+         os.rename(temp_file, local_path)
+         logger.info(f"Successfully transferred {remote_path} to {local_path}")
+
+         return checksum
+     except Exception as e:
+         logger.error(f"Error transferring {remote_path}: {str(e)}")
+         # Clean up temp file if it exists
+         if os.path.exists(temp_file):
+             os.remove(temp_file)
+         raise
+
+
+ def is_file_stable(sftp, remote_path, wait_time=30):
+     """
+     Check if a file is stable (not being written to) by comparing its size
+     before and after a short wait period.
+     """
+     try:
+         # Get initial stats
+         initial_stat = sftp.stat(remote_path)
+         initial_size = initial_stat.st_size
+
+         # Wait a bit
+         time.sleep(wait_time)
+
+         # Get updated stats
+         updated_stat = sftp.stat(remote_path)
+         updated_size = updated_stat.st_size
+
+         # File is stable if size hasn't changed
+         is_stable = initial_size == updated_size
+
+         if not is_stable:
+             logger.info(f"File {remote_path} is still being written to (size changed from {initial_size} to {updated_size})")
+
+         return is_stable, updated_stat
+     except Exception as e:
+         logger.error(f"Error checking if {remote_path} is stable: {str(e)}")
+         return False, None
+
+
+ def transfer_tar_files(sftp):
+     """Transfer all record_*.tar files that haven't been transferred yet."""
+     transferred_count = 0
+
+     try:
+         # List all tar files
+         tar_pattern = re.compile(r'record_.*\.tar$')
+         remote_files = sftp.listdir(REMOTE_DATA_DIR)
+         tar_files = [f for f in remote_files if tar_pattern.match(f)]
+
+         logger.info(f"Found {len(tar_files)} TAR files on remote server")
+
+         for tar_file in tar_files:
+             remote_path = os.path.join(REMOTE_DATA_DIR, tar_file)
+             local_path = os.path.join(LOCAL_DATA_DIR, tar_file)
+
+             # Get file stats
+             try:
+                 stat = sftp.stat(remote_path)
+             except FileNotFoundError:
+                 logger.warning(f"File {remote_path} disappeared, skipping")
+                 continue
+
+             # Skip if already transferred with same size and mtime
+             if is_file_transferred(tar_file, stat.st_size, stat.st_mtime):
+                 logger.debug(f"Skipping already transferred file: {tar_file}")
+                 continue
+
+             # Check if file is stable (not being written to)
+             is_stable, updated_stat = is_file_stable(sftp, remote_path)
+             if not is_stable:
+                 logger.info(f"Skipping unstable file: {tar_file}")
+                 continue
+
+             # Transfer the file
+             try:
+                 checksum = safe_transfer_file(sftp, remote_path, local_path)
+                 mark_file_transferred(tar_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
+                 transferred_count += 1
+             except Exception as e:
+                 logger.error(f"Failed to transfer {tar_file}: {str(e)}")
+                 continue
+
+         logger.info(f"Transferred {transferred_count} new TAR files")
+         return transferred_count
+     except Exception as e:
+         logger.error(f"Error in transfer_tar_files: {str(e)}")
+         return 0
+
+
+ def transfer_pkl_file(sftp):
+     """Transfer the PKL file if it hasn't been transferred yet or has changed."""
+     pkl_file = "image_action_mapping_with_key_states.pkl"
+     remote_path = os.path.join(REMOTE_DATA_DIR, pkl_file)
+     local_path = os.path.join(LOCAL_DATA_DIR, pkl_file)
+
+     try:
+         # Check if file exists
+         try:
+             stat = sftp.stat(remote_path)
+         except FileNotFoundError:
+             logger.warning(f"PKL file {remote_path} not found")
+             return False
+
+         # Skip if already transferred with same size and mtime
+         if is_file_transferred(pkl_file, stat.st_size, stat.st_mtime):
+             logger.debug("Skipping already transferred PKL file (unchanged)")
+             return True
+
+         # Check if file is stable
+         is_stable, updated_stat = is_file_stable(sftp, remote_path)
+         if not is_stable:
+             logger.info("PKL file is still being written to, skipping")
+             return False
+
+         # Transfer the file
+         checksum = safe_transfer_file(sftp, remote_path, local_path)
+         mark_file_transferred(pkl_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
+
+         # Update state
+         update_transfer_state("last_pkl_transfer", datetime.now().isoformat())
+
+         logger.info("Successfully transferred PKL file")
+         return True
+     except Exception as e:
+         logger.error(f"Error transferring PKL file: {str(e)}")
+         return False
+
+
+ def transfer_csv_file(sftp):
+     """Transfer the CSV file if it hasn't been transferred yet or has changed."""
+     csv_file = "train_dataset.target_frames.csv"
+     remote_path = os.path.join(REMOTE_DATA_DIR, csv_file)
+     local_path = os.path.join(LOCAL_DATA_DIR, csv_file)
+
+     try:
+         # Check if file exists
+         try:
+             stat = sftp.stat(remote_path)
+         except FileNotFoundError:
+             logger.warning(f"CSV file {remote_path} not found")
+             return False
+
+         # Skip if already transferred with same size and mtime
+         if is_file_transferred(csv_file, stat.st_size, stat.st_mtime):
+             logger.debug("Skipping already transferred CSV file (unchanged)")
+             return True
+
+         # Check if file is stable
+         is_stable, updated_stat = is_file_stable(sftp, remote_path)
+         if not is_stable:
+             logger.info("CSV file is still being written to, skipping")
+             return False
+
+         # Transfer the file
+         checksum = safe_transfer_file(sftp, remote_path, local_path)
+         mark_file_transferred(csv_file, updated_stat.st_size, updated_stat.st_mtime, checksum)
+
+         # Update state
+         update_transfer_state("last_csv_transfer", datetime.now().isoformat())
+
+         logger.info("Successfully transferred CSV file")
+         return True
+     except Exception as e:
+         logger.error(f"Error transferring CSV file: {str(e)}")
+         return False
+
+
+ def run_transfer_cycle():
+     """Run a complete transfer cycle."""
+     client = None
+     try:
+         # Connect to the remote server
+         client = create_ssh_client()
+         sftp = client.open_sftp()
+
+         # Step 1: Transfer TAR files
+         tar_count = transfer_tar_files(sftp)
+
+         # Step 2: Transfer PKL file (only if we have new TAR files or it's changed)
+         if tar_count > 0 or not get_transfer_state("last_pkl_transfer"):
+             pkl_success = transfer_pkl_file(sftp)
+         else:
+             pkl_success = True  # Assume success if we didn't need to transfer
+
+         # Step 3: Transfer CSV file (only if PKL transfer succeeded)
+         if pkl_success:
+             csv_success = transfer_csv_file(sftp)
+         else:
+             logger.warning("Skipping CSV transfer because PKL transfer failed")
+             csv_success = False
+
+         return tar_count > 0 or pkl_success or csv_success
+     except Exception as e:
+         logger.error(f"Error in transfer cycle: {str(e)}")
+         return False
+     finally:
+         if client:
+             client.close()
+
+
+ def main():
+     """Main function for the data transfer script."""
+     logger.info("Starting data transfer script")
+
+     # Initialize the database
+     initialize_database()
+
+     try:
+         while True:
+             logger.info("Starting new transfer cycle")
+
+             changes = run_transfer_cycle()
+
+             if changes:
+                 logger.info("Transfer cycle completed with new files transferred")
+             else:
+                 logger.info("Transfer cycle completed with no changes")
+
+             logger.info(f"Sleeping for {POLL_INTERVAL} seconds before next check")
+             time.sleep(POLL_INTERVAL)
+
+     except KeyboardInterrupt:
+         logger.info("Script terminated by user")
+     except Exception as e:
+         logger.error(f"Unhandled exception: {str(e)}")
+         raise
+
+
+ if __name__ == "__main__":
+     main()
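
For reference, a minimal one-shot driver is sketched below; it is not part of this commit. It assumes sync_train_dataset.py sits on the local Python path with the REMOTE_* settings filled in and paramiko installed, and it simply calls the functions defined above once instead of looping every POLL_INTERVAL seconds. The hypothetical file name one_shot_sync.py is illustrative only.

    # one_shot_sync.py -- hypothetical usage sketch, not part of this commit
    from sync_train_dataset import initialize_database, run_transfer_cycle

    initialize_database()        # creates transfer_state.db and its tables if missing
    if run_transfer_cycle():     # pulls new TAR files first, then the PKL, then the CSV
        print("Transfer cycle finished; see data_transfer.log for details")
    else:
        print("Transfer cycle failed or made no progress; see data_transfer.log")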