File size: 9,647 Bytes
7c5aa99 f94c5ea c259678 0430da2 e42fc15 5c35386 0430da2 4a4435c 0430da2 f94c5ea 4a4435c 5d41434 1e3138f c259678 f94c5ea 5d41434 c259678 3d37920 c259678 9dd78b5 c259678 9dd78b5 0430da2 1e3138f 5c35386 c259678 5c35386 c259678 5c35386 c259678 5c35386 a9ae246 c259678 a9ae246 c259678 a9ae246 c259678 a9ae246 c259678 9dd78b5 4410500 d6a5933 c259678 f94c5ea c259678 9dd78b5 c259678 9dd78b5 c259678 5c35386 d6a5933 c259678 3d37920 c259678 3d37920 f94c5ea c259678 f94c5ea c259678 5c35386 4a4435c 3d37920 c259678 5d41434 c259678 5d41434 c259678 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
# -*- coding: utf-8 -*-
import os
import gc
import gradio as gr
import requests
import time
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = os.getcwd() # Χρησιμοποιεί τον τρέχοντα φάκελο
#TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000 # Αυξημένο όριο δειγμάτων
DEFAULT_CHUNK_SIZE = 200000 # Μεγαλύτερο chunk size
BATCH_SIZE = 1000 # Μέγεθος batch για φόρτωση δεδομένων
NUM_WORKERS = 4 # Αριθμός workers για πολυνηματική επεξεργασία
# Παγκόσμια μεταβλητή ελέγχου
STOP_COLLECTION = False
# Καταγραφή εκκίνησης
startup_log = f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n"
print(startup_log)
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων με ομαδοποίηση."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
batch = "\n".join(texts) + "\n"
f.write(batch)
def create_iterator(dataset_name, configs, split):
"""Βελτιωμένο iterator με batch φόρτωση και caching."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(
dataset_name,
name=config,
split=split,
streaming=True,
cache_dir="./dataset_cache" # Ενεργοποίηση cache
)
# Φόρτωση δεδομένων σε batches
while True:
batch = list(dataset.take(BATCH_SIZE))
if not batch:
break
dataset = dataset.skip(BATCH_SIZE)
# Πολυνηματική επεξεργασία batch
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
processed_texts = list(executor.map(process_example, batch))
yield from filter(None, processed_texts)
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης: {config}: {e}")
def process_example(example):
"""Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
try:
text = example.get('text', '').strip()
if text and detect(text) in ['el', 'en']: # Φιλτράρισμα γλώσσας
return text
return None
except:
return None
def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
"""Βελτιωμένη συλλογή δεδομένων με μεγάλα chunks."""
global STOP_COLLECTION
STOP_COLLECTION = False
total_processed = len(load_checkpoint())
progress_messages = [
f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}",
f"⚙️ Ρυθμίσεις: Chunk Size={chunk_size}, Workers={NUM_WORKERS}"
]
dataset_iterator = create_iterator(dataset_name, configs, split)
chunk = []
while not STOP_COLLECTION and total_processed < max_samples:
try:
# Φόρτωση chunk
while len(chunk) < chunk_size:
text = next(dataset_iterator)
if text:
chunk.append(text)
total_processed += 1
if total_processed >= max_samples:
break
# Αποθήκευση chunk
if chunk:
append_to_checkpoint(chunk)
progress_messages.append(
f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})"
)
chunk = []
# Εκκαθάριση μνήμης
gc.collect()
except StopIteration:
progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
break
except Exception as e:
progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
break
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Βελτιωμένη εκπαίδευση tokenizer με χρήση cache."""
print("🚀 Εκκίνηση εκπαίδευσης...")
all_texts = load_checkpoint()
# Παράλληλη επεξεργασία για εκπαίδευση
tokenizer = train_tokenizer(
all_texts,
vocab_size=vocab_size,
min_frequency=min_freq,
output_dir=TOKENIZER_DIR,
num_threads=NUM_WORKERS # Παράλληλη επεξεργασία
)
# Φόρτωση και δοκιμή tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Δημιουργία γραφήματος
fig, ax = plt.subplots()
ax.hist([len(t) for t in encoded.tokens], bins=20)
ax.set_xlabel('Μήκος Token')
ax.set_ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return ("✅ Εκπαίδευση ολοκληρώθηκε!", decoded, Image.open(img_buffer))
print(f"Ο tokenizer αποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}")
def analyze_checkpoint():
"""Νέα λειτουργία ανάλυσης δεδομένων."""
texts = load_checkpoint()
if not texts:
return "Δεν βρέθηκαν δεδομένα για ανάλυση."
# Βασική στατιστική
total_chars = sum(len(t) for t in texts)
avg_length = total_chars / len(texts) if texts else 0
# Ανάλυση γλώσσας
languages = {}
for t in texts[:1000]: # Δειγματοληψία για ταχύτητα
try:
lang = detect(t)
languages[lang] = languages.get(lang, 0) + 1
except:
continue
report = [
f"📊 Σύνολο δειγμάτων: {len(texts)}",
f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
"🌍 Γλώσσες (δείγμα 1000):",
*[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
]
return "\n".join(report)
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Βελτιωμένος Wikipedia Tokenizer Trainer")
with gr.Row():
with gr.Column(scale=2):
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
split = gr.Dropdown(["train"], value="train", label="Split")
chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
vocab_size = gr.Slider(20000, 200000, value=50000, step=10000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Maximum Samples")
with gr.Row():
start_btn = gr.Button("Start", variant="primary")
stop_btn = gr.Button("Stop", variant="stop")
restart_btn = gr.Button("Restart")
analyze_btn = gr.Button("Analyze Data")
train_btn = gr.Button("Train Tokenizer", variant="primary")
with gr.Column(scale=3):
progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
gr.Markdown("### Αποτελέσματα")
decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο")
token_distribution = gr.Image(label="Κατανομή Tokens")
# Event handlers
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
stop_btn.click(lambda: "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
restart_btn.click(lambda: "🔄 Επαναφορά...", None, progress).then(restart_collection, None, progress)
analyze_btn.click(analyze_checkpoint, None, progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
demo.queue(concurrency_count=4).launch() |