File size: 9,647 Bytes
7c5aa99
f94c5ea
c259678
0430da2
e42fc15
5c35386
0430da2
4a4435c
0430da2
f94c5ea
4a4435c
5d41434
1e3138f
 
c259678
f94c5ea
5d41434
 
 
c259678
 
 
 
3d37920
c259678
 
 
 
9dd78b5
c259678
9dd78b5
0430da2
1e3138f
 
 
 
5c35386
c259678
5c35386
 
 
 
 
 
c259678
5c35386
c259678
 
5c35386
a9ae246
c259678
a9ae246
c259678
a9ae246
 
c259678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9ae246
c259678
 
 
 
 
 
 
 
 
 
 
 
 
 
9dd78b5
4410500
d6a5933
c259678
 
 
 
 
 
f94c5ea
c259678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dd78b5
c259678
 
9dd78b5
c259678
 
5c35386
d6a5933
c259678
 
3d37920
c259678
 
 
 
 
 
 
 
 
 
 
3d37920
f94c5ea
 
c259678
 
 
 
 
 
f94c5ea
 
 
c259678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c35386
4a4435c
3d37920
c259678
 
5d41434
c259678
 
 
 
 
 
5d41434
 
c259678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# -*- coding: utf-8 -*-
import os
import gc
import gradio as gr
import requests
import time
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0

# Ρυθμίσεις
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = os.getcwd()  # Χρησιμοποιεί τον τρέχοντα φάκελο
#TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000  # Αυξημένο όριο δειγμάτων
DEFAULT_CHUNK_SIZE = 200000  # Μεγαλύτερο chunk size
BATCH_SIZE = 1000  # Μέγεθος batch για φόρτωση δεδομένων
NUM_WORKERS = 4    # Αριθμός workers για πολυνηματική επεξεργασία

# Παγκόσμια μεταβλητή ελέγχου
STOP_COLLECTION = False

# Καταγραφή εκκίνησης
startup_log = f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n"
print(startup_log)

def load_checkpoint():
    """Φόρτωση δεδομένων από το checkpoint."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []

def append_to_checkpoint(texts):
    """Αποθήκευση δεδομένων με ομαδοποίηση."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        batch = "\n".join(texts) + "\n"
        f.write(batch)

def create_iterator(dataset_name, configs, split):
    """Βελτιωμένο iterator με batch φόρτωση και caching."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    
    for config in configs_list:
        try:
            dataset = load_dataset(
                dataset_name,
                name=config,
                split=split,
                streaming=True,
                cache_dir="./dataset_cache"  # Ενεργοποίηση cache
            )
            
            # Φόρτωση δεδομένων σε batches
            while True:
                batch = list(dataset.take(BATCH_SIZE))
                if not batch:
                    break
                dataset = dataset.skip(BATCH_SIZE)
                
                # Πολυνηματική επεξεργασία batch
                with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
                    processed_texts = list(executor.map(process_example, batch))
                
                yield from filter(None, processed_texts)
                
        except Exception as e:
            print(f"⚠️ Σφάλμα φόρτωσης: {config}: {e}")

def process_example(example):
    """Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
    try:
        text = example.get('text', '').strip()
        if text and detect(text) in ['el', 'en']:  # Φιλτράρισμα γλώσσας
            return text
        return None
    except:
        return None

def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
    """Βελτιωμένη συλλογή δεδομένων με μεγάλα chunks."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    total_processed = len(load_checkpoint())
    
    progress_messages = [
        f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}",
        f"⚙️ Ρυθμίσεις: Chunk Size={chunk_size}, Workers={NUM_WORKERS}"
    ]
    
    dataset_iterator = create_iterator(dataset_name, configs, split)
    chunk = []
    
    while not STOP_COLLECTION and total_processed < max_samples:
        try:
            # Φόρτωση chunk
            while len(chunk) < chunk_size:
                text = next(dataset_iterator)
                if text:
                    chunk.append(text)
                    total_processed += 1
                    if total_processed >= max_samples:
                        break
        
            # Αποθήκευση chunk
            if chunk:
                append_to_checkpoint(chunk)
                progress_messages.append(
                    f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})"
                )
                chunk = []
                
                # Εκκαθάριση μνήμης
                gc.collect()
                
        except StopIteration:
            progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
            break
        except Exception as e:
            progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
            break
    
    return "\n".join(progress_messages)

def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """Βελτιωμένη εκπαίδευση tokenizer με χρήση cache."""
    print("🚀 Εκκίνηση εκπαίδευσης...")
    all_texts = load_checkpoint()
    
    # Παράλληλη επεξεργασία για εκπαίδευση
    tokenizer = train_tokenizer(
        all_texts,
        vocab_size=vocab_size,
        min_frequency=min_freq,
        output_dir=TOKENIZER_DIR,
        num_threads=NUM_WORKERS  # Παράλληλη επεξεργασία
    )
    
    # Φόρτωση και δοκιμή tokenizer
    trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
    encoded = trained_tokenizer.encode(test_text)
    decoded = trained_tokenizer.decode(encoded.ids)
    
    # Δημιουργία γραφήματος
    fig, ax = plt.subplots()
    ax.hist([len(t) for t in encoded.tokens], bins=20)
    ax.set_xlabel('Μήκος Token')
    ax.set_ylabel('Συχνότητα')
    img_buffer = BytesIO()
    plt.savefig(img_buffer, format='png')
    plt.close()
    
    return ("✅ Εκπαίδευση ολοκληρώθηκε!", decoded, Image.open(img_buffer))
    print(f"Ο tokenizer αποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}")

def analyze_checkpoint():
    """Νέα λειτουργία ανάλυσης δεδομένων."""
    texts = load_checkpoint()
    if not texts:
        return "Δεν βρέθηκαν δεδομένα για ανάλυση."
    
    # Βασική στατιστική
    total_chars = sum(len(t) for t in texts)
    avg_length = total_chars / len(texts) if texts else 0
    
    # Ανάλυση γλώσσας
    languages = {}
    for t in texts[:1000]:  # Δειγματοληψία για ταχύτητα
        try:
            lang = detect(t)
            languages[lang] = languages.get(lang, 0) + 1
        except:
            continue
    
    report = [
        f"📊 Σύνολο δειγμάτων: {len(texts)}",
        f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
        "🌍 Γλώσσες (δείγμα 1000):",
        *[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
    ]
    
    return "\n".join(report)

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Βελτιωμένος Wikipedia Tokenizer Trainer")
    
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
            configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
            split = gr.Dropdown(["train"], value="train", label="Split")
            chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
            vocab_size = gr.Slider(20000, 200000, value=50000, step=10000, label="Vocabulary Size")
            min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Maximum Samples")
            
            with gr.Row():
                start_btn = gr.Button("Start", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")
                restart_btn = gr.Button("Restart")
            
            analyze_btn = gr.Button("Analyze Data")
            train_btn = gr.Button("Train Tokenizer", variant="primary")
        
        with gr.Column(scale=3):
            progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
            gr.Markdown("### Αποτελέσματα")
            decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο")
            token_distribution = gr.Image(label="Κατανομή Tokens")

    # Event handlers
    start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
    stop_btn.click(lambda: "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
    restart_btn.click(lambda: "🔄 Επαναφορά...", None, progress).then(restart_collection, None, progress)
    analyze_btn.click(analyze_checkpoint, None, progress)
    train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
                   [progress, decoded_text, token_distribution])

demo.queue(concurrency_count=4).launch()