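"""Gradio app for collecting Wikipedia text and training a tokenizer on it.

Texts are streamed from a Hugging Face dataset (e.g. wikimedia/wikipedia) and
checkpointed to a local file; the collected samples can be analyzed for their
language distribution with langdetect and then used to train a tokenizer via
the local train_tokenizer module.
"""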
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import requests
from datasets import load_dataset
from langdetect import detect, DetectorFactory
from PIL import Image  # Pillow ships with Gradio; used to hand the histogram PNG to gr.Image
from tokenizers import Tokenizer

from train_tokenizer import train_tokenizer

# Make langdetect results deterministic across runs
DetectorFactory.seed = 0

CHECKPOINT_FILE = "checkpoint.txt"  # collected texts, one per line
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
CHUNK_SIZE = 1000        # texts buffered in memory before each append to the checkpoint
MAX_SAMPLES = 3_000_000  # hard cap on collected samples

# Module-level flag checked by collect_samples(); stop_collection() sets it from a separate Gradio request.
STOP_COLLECTION = False
def fetch_splits(dataset_name):
    """Fetch the dataset's splits from the Hugging Face datasets server (not currently wired into the UI)."""
    try:
        response = requests.get(f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}", timeout=10)
        response.raise_for_status()
        data = response.json()

        splits_info = {}
        for split in data['splits']:
            config = split['config']
            split_name = split['split']
            if config not in splits_info:
                splits_info[config] = []
            splits_info[config].append(split_name)

        return {
            "splits": splits_info,
            "viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
        }
    except Exception as e:
        raise gr.Error(f"Error while fetching splits: {str(e)}")


def create_iterator(dataset_name, configs, split):
    """Load the dataset in streaming mode and yield its texts one by one."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
            for example in dataset:
                text = example.get('text', '')
                if text:
                    yield text
        except Exception as e:
            print(f"⚠️ Error loading dataset for config {config}: {e}")


def append_to_checkpoint(texts):
    """Append the given texts to the checkpoint file, one text per line."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        for t in texts:
            f.write(t + "\n")


def load_checkpoint():
    """Load the collected texts from the checkpoint file, if it exists."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []
def analyze_checkpoint(num_samples=1000):
    """
    Read the first num_samples samples from the checkpoint and return the language distribution.
    """
    if not os.path.exists(CHECKPOINT_FILE):
        return "The checkpoint file does not exist."

    with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()

    # Slicing handles the case where the checkpoint holds fewer than num_samples lines.
    sample_lines = lines[:num_samples]

    language_counts = {}
    total = 0
    for line in sample_lines:
        try:
            lang = detect(line)
            language_counts[lang] = language_counts.get(lang, 0) + 1
            total += 1
        except Exception:
            continue

    if total == 0:
        return "No valid samples were found to analyze."

    report = "Analysis results:\n"
    for lang, count in language_counts.items():
        report += f"Language {lang}: {count/total*100:.2f}%\n"

    return report


def collect_samples(dataset_name, configs, split):
    """
    Collect samples from the dataset until MAX_SAMPLES is reached
    or a stop is requested (STOP_COLLECTION).
    """
    global STOP_COLLECTION
    STOP_COLLECTION = False
    total_processed = len(load_checkpoint())
    progress_messages = [f"📌 {total_processed} samples are already in the checkpoint."]

    dataset_iterator = create_iterator(dataset_name, configs, split)
    new_texts = []

    for text in dataset_iterator:
        if STOP_COLLECTION:
            progress_messages.append("⏹️ Collection was stopped by the user.")
            break

        new_texts.append(text)
        total_processed += 1

        # Flush to disk in chunks so a long run can be resumed from the checkpoint.
        if len(new_texts) >= CHUNK_SIZE:
            append_to_checkpoint(new_texts)
            progress_messages.append(f"✅ Saved {total_processed} samples to the checkpoint.")
            new_texts = []

        if total_processed >= MAX_SAMPLES:
            progress_messages.append("⚠️ Sample limit reached.")
            break

    if new_texts:
        append_to_checkpoint(new_texts)
        progress_messages.append(f"✅ Final batch saved ({total_processed} samples).")

    return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """
    Train the tokenizer on the samples collected in the checkpoint.

    The dataset arguments mirror the UI inputs, but training itself only uses the checkpoint.
    """
    print("🚀 Starting tokenizer training on the checkpoint data...")
    all_texts = load_checkpoint()
    # Slider values may arrive as floats, so cast them before passing them to the trainer.
    train_tokenizer(all_texts, int(vocab_size), int(min_freq), TOKENIZER_DIR)

    # Reload the tokenizer that train_tokenizer is expected to have written to TOKENIZER_FILE.
    trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)

    # Encode and decode the test sentence as a quick sanity check.
    encoded = trained_tokenizer.encode(test_text)
    decoded = trained_tokenizer.decode(encoded.ids)

    # Histogram of token lengths for the test sentence.
    token_lengths = [len(t) for t in encoded.tokens]
    plt.figure()
    plt.hist(token_lengths, bins=20)
    plt.xlabel('Token length')
    plt.ylabel('Frequency')
    img_buffer = BytesIO()
    plt.savefig(img_buffer, format='png')
    plt.close()

    # gr.Image expects a filepath, numpy array, or PIL image rather than raw bytes,
    # so convert the saved PNG before returning it.
    img_buffer.seek(0)
    histogram_image = Image.open(img_buffer)

    return (f"✅ Training complete!\nSaved to: {TOKENIZER_DIR}",
            decoded,
            histogram_image)
def start_collection(dataset_name, configs, split):
    """Start collecting samples (or resume collection after a restart)."""
    msg = collect_samples(dataset_name, configs, split)
    return msg


def stop_collection():
    """Set the flag that interrupts sample collection."""
    global STOP_COLLECTION
    STOP_COLLECTION = True
    return "Collection was stopped by the user."


def restart_collection():
    """
    Reset collection by deleting the checkpoint and
    clearing the flag so that a new collection run can start.
    """
    global STOP_COLLECTION
    STOP_COLLECTION = False
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    return "The checkpoint was deleted. You can start a new collection run."
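# Gradio UI: the collection, analysis, and training buttons all report into the shared "Progress" textbox.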
with gr.Blocks() as demo:
    gr.Markdown("## Wikipedia Tokenizer Trainer with Collection, Analysis & Training")

    with gr.Row():
        with gr.Column():
            dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
            configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
            split = gr.Dropdown(choices=["train"], value="train", label="Split")
            vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
            min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            start_btn = gr.Button("Start Collection")
            stop_btn = gr.Button("Stop Collection")
            analyze_btn = gr.Button("Analyze Samples")
            restart_btn = gr.Button("Restart Collection")
            train_btn = gr.Button("Train Tokenizer")
        with gr.Column():
            progress = gr.Textbox(label="Progress", interactive=False, lines=10)
            results_text = gr.Textbox(label="Test Decoded Text", interactive=False)
            results_plot = gr.Image(label="Token Length Distribution")
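    # Offers whatever tokenizer file existed at launch; it is not refreshed after training
    # unless this component is also added to train_btn's outputs.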
    initial_file_value = TOKENIZER_FILE if os.path.exists(TOKENIZER_FILE) else None
    download_button = gr.File(label="Download Tokenizer", value=initial_file_value)

    start_btn.click(fn=start_collection,
                    inputs=[dataset_name, configs, split],
                    outputs=progress)

    stop_btn.click(fn=stop_collection,
                   inputs=[],
                   outputs=progress)

    analyze_btn.click(fn=lambda: analyze_checkpoint(1000),
                      inputs=[],
                      outputs=progress)

    restart_btn.click(fn=restart_collection,
                      inputs=[],
                      outputs=progress)

    train_btn.click(fn=train_tokenizer_fn,
                    inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text],
                    outputs=[progress, results_text, results_plot])
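# Start the Gradio server (by default on http://127.0.0.1:7860).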
demo.launch()