# -*- coding: utf-8 -*-
import os
import gradio as gr
import requests
import tempfile
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
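
# `train_tokenizer` lives in the local train_tokenizer.py module, which is not
# shown here. A minimal sketch of the interface this script assumes (BPE via
# the `tokenizers` library; the body below is inferred from the call site, not
# confirmed):
#
#     from tokenizers import Tokenizer, models, pre_tokenizers, trainers
#
#     def train_tokenizer(texts, vocab_size, min_freq, output_dir):
#         tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
#         tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
#         trainer = trainers.BpeTrainer(vocab_size=vocab_size,
#                                       min_frequency=min_freq,
#                                       special_tokens=["<unk>", "<pad>"])
#         tokenizer.train_from_iterator(texts, trainer=trainer)
#         os.makedirs(output_dir, exist_ok=True)
#         tokenizer.save(os.path.join(output_dir, "tokenizer.json"))
#         return tokenizer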

# Checkpointing settings
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
CHUNK_SIZE = 1000  # Batch size per checkpoint write

def fetch_splits(dataset_name):
    """Fetch the dataset's splits from the Hugging Face datasets server."""
    try:
        response = requests.get(
            f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}",
            timeout=10,
        )
        response.raise_for_status()
        data = response.json()
        splits_info = {}
        for split in data['splits']:
            config = split['config']
            split_name = split['split']
            if config not in splits_info:
                splits_info[config] = []
            splits_info[config].append(split_name)
        return {
            "splits": splits_info,
            "viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
        }
    except Exception as e:
        raise gr.Error(f"Error while fetching splits: {str(e)}")
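
# Note: fetch_splits is not wired into the UI below; it can be used on its
# own. Hypothetical output shape, based on the datasets-server response:
#
#     info = fetch_splits("wikimedia/wikipedia")
#     # info["splits"] -> {"20231101.el": ["train"], "20231101.en": ["train"], ...}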

def create_iterator(dataset_name, configs, split):
    """Stream the dataset and yield its texts one by one."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            # streaming=True avoids downloading the full dump up front
            dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
            for example in dataset:
                text = example.get('text', '')
                if text:
                    yield text
        except Exception as e:
            print(f"⚠️ Error loading dataset for config {config}: {e}")
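
# Example (streams Greek Wikipedia articles without loading the whole dump
# into memory):
#
#     for text in create_iterator("wikimedia/wikipedia", "20231101.el", "train"):
#         print(text[:80])
#         break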

def append_to_checkpoint(texts):
    """Append texts to the checkpoint file, one per line."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        for t in texts:
            f.write(t + "\n")

def load_checkpoint():
    """Load texts from the checkpoint file, if it exists."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []
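
# The checkpoint is a plain newline-separated text file, so an interrupted run
# resumes from the samples already collected. One caveat (not handled here):
# a text containing embedded newlines is split into multiple entries on reload.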

def train_and_test(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """Train the tokenizer, then test it on a sample text."""
    print("🚀 Starting the training process...")
    all_texts = load_checkpoint()
    total_processed = len(all_texts)
    print(f"📌 {total_processed} samples already in the checkpoint.")
    dataset_iterator = create_iterator(dataset_name, configs, split)
    new_texts = []
    for text in dataset_iterator:
        new_texts.append(text)
        total_processed += 1
        if len(new_texts) >= CHUNK_SIZE:
            append_to_checkpoint(new_texts)
            print(f"✅ Saved {total_processed} samples to the checkpoint.")
            new_texts = []
    if new_texts:
        append_to_checkpoint(new_texts)
        print(f"✅ Final batch saved ({total_processed} samples).")
    # Train the tokenizer on everything collected so far
    all_texts = load_checkpoint()
    train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR)
    # Load the trained tokenizer
    trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
    # Test round-trip encoding/decoding
    encoded = trained_tokenizer.encode(test_text)
    decoded = trained_tokenizer.decode(encoded.ids)
    # Token length distribution plot; gr.Image expects a file path, a numpy
    # array, or a PIL image rather than raw PNG bytes, so save to a temp file
    token_lengths = [len(t) for t in encoded.tokens]
    fig = plt.figure()
    plt.hist(token_lengths, bins=20)
    plt.xlabel("Token Length")
    plt.ylabel("Frequency")
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        fig.savefig(tmp.name, format="png")
        plot_path = tmp.name
    plt.close(fig)
    return f"✅ Training complete!\nSaved to folder: {TOKENIZER_DIR}", decoded, plot_path
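
# Standalone use of the saved tokenizer, outside the UI (a minimal sketch):
#
#     tok = Tokenizer.from_file("tokenizer_model/tokenizer.json")
#     enc = tok.encode("Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.")
#     print(enc.tokens)
#     print(tok.decode(enc.ids))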

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Wikipedia Tokenizer Trainer with Checkpointing")
    dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
    configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs (comma-separated)")
    split = gr.Dropdown(choices=["train"], value="train", label="Split")
    vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
    min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
    test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
    train_btn = gr.Button("Train")
    progress = gr.Textbox(label="Progress", interactive=False)
    decoded_text = gr.Textbox(label="Decoded Test Text", interactive=False)
    results_plot = gr.Image(label="Token Length Distribution")
    # Offer the tokenizer for download only if one already exists
    if os.path.exists(TOKENIZER_FILE):
        initial_file_value = TOKENIZER_FILE
    else:
        initial_file_value = None  # starts empty until training produces the file
    download_button = gr.File(label="Download Tokenizer", value=initial_file_value)
    train_btn.click(
        train_and_test,
        inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text],
        outputs=[progress, decoded_text, results_plot],
    )

demo.launch()