fisherman611 commited on
Commit
a4a39ab
·
verified ·
1 Parent(s): c70f97e

Create utils/split_data.py

Browse files
Files changed (1) hide show
  1. utils/split_data.py +76 -0
utils/split_data.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import shutil
4
+ from sklearn.model_selection import train_test_split
5
+ from tqdm.auto import tqdm
6
+
7
+ df_2014 = pd.read_csv('data/CROHME/2014/caption.txt', sep='\t', header=None, names=['filenames', 'captions'])
8
+ df_2016 = pd.read_csv('data/CROHME/2016/caption.txt', sep='\t', header=None, names=['filenames', 'captions'])
9
+ df_2019 = pd.read_csv('data/CROHME/2019/caption.txt', sep='\t', header=None, names=['filenames', 'captions'])
10
+ df_train = pd.read_csv('data/CROHME/train/caption.txt', sep='\t', header=None, names=['filenames', 'captions'])
11
+
12
+ data = pd.concat(
13
+ [
14
+ df_2014,
15
+ df_2016,
16
+ df_2019,
17
+ df_train
18
+ ]
19
+ )
20
+
21
+ # First, split off 10% of the data to train and test sets
22
+ train, test = train_test_split(data, test_size=0.1, random_state=42)
23
+
24
+ # Second, split off 10% of the training data to train and validation sets
25
+ train, val = train_test_split(train, test_size=0.1, random_state=42)
26
+
27
+ print("Train shape:", train.shape)
28
+ print("Test shape:", test.shape)
29
+ print("Validation shape:", val.shape)
30
+
31
+ train_filenames = train['filenames'].tolist()
32
+ train_captions = train['captions'].tolist()
33
+
34
+ test_filenames = test['filenames'].tolist()
35
+ test_captions = test['captions'].tolist()
36
+
37
+ val_filenames = val['filenames'].tolist()
38
+ val_captions = val['captions'].tolist()
39
+
40
+ # Extract captions.txt for each split
41
+ with open('data/CROHME_splitted/train/caption.txt', 'w', encoding='utf-8') as f:
42
+ for filename, caption in zip(train_filenames, train_captions):
43
+ f.write(f"{filename}\t{caption}\n")
44
+
45
+ with open('data/CROHME_splitted/test/caption.txt', 'w', encoding='utf-8') as f:
46
+ for filename, caption in zip(test_filenames, test_captions):
47
+ f.write(f"{filename}\t{caption}\n")
48
+
49
+ with open('data/CROHME_splitted/val/caption.txt', 'w', encoding='utf-8') as f:
50
+ for filename, caption in zip(val_filenames, val_captions):
51
+ f.write(f"{filename}\t{caption}\n")
52
+
53
+
54
+ IMAGES_DIR = 'data/images'
55
+ TRAIN_DIR = 'data/CROHME_splitted/train/img'
56
+ TEST_DIR = 'data/CROHME_splitted/test/img'
57
+ VAL_DIR = 'data/CROHME_splitted/val/img'
58
+
59
+ os.makedirs(TRAIN_DIR, exist_ok=True)
60
+ os.makedirs(TEST_DIR, exist_ok=True)
61
+ os.makedirs(VAL_DIR, exist_ok=True)
62
+
63
+ for train_filename in tqdm(train_filenames, desc="Copying train images"):
64
+ src = os.path.join(IMAGES_DIR, train_filename) + '.bmp' # Ensure the file extension is correct
65
+ dst = os.path.join(TRAIN_DIR, train_filename) + '.bmp'
66
+ shutil.copy(src, dst)
67
+
68
+ for test_filename in tqdm(test_filenames, desc="Copying test images"):
69
+ src = os.path.join(IMAGES_DIR, test_filename) + '.bmp'
70
+ dst = os.path.join(TEST_DIR, test_filename) + '.bmp'
71
+ shutil.copy(src, dst)
72
+
73
+ for val_filename in tqdm(val_filenames, desc="Copying validation images"):
74
+ src = os.path.join(IMAGES_DIR, val_filename) + '.bmp'
75
+ dst = os.path.join(VAL_DIR, val_filename) + '.bmp'
76
+ shutil.copy(src, dst)