GPT2-PBE / app.py
tymbos's picture
Update app.py
d6a5933 verified
raw
history blame
10.2 kB
# -*- coding: utf-8 -*-
import os
import gradio as gr
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις checkpointing και αποθήκευσης του tokenizer
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
CHUNK_SIZE = 1000 # Μέγεθος batch για checkpoint
MAX_SAMPLES = 3000000 # Όριο δειγμάτων (προσαρμόστε όπως χρειάζεται)
# Παγκόσμια μεταβλητή ελέγχου συλλογής
STOP_COLLECTION = False
def fetch_splits(dataset_name):
"""Ανάκτηση των splits του dataset από το Hugging Face."""
try:
response = requests.get(f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}", timeout=10)
response.raise_for_status()
data = response.json()
splits_info = {}
for split in data['splits']:
config = split['config']
split_name = split['split']
if config not in splits_info:
splits_info[config] = []
splits_info[config].append(split_name)
return {
"splits": splits_info,
"viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
}
except Exception as e:
raise gr.Error(f"Σφάλμα κατά την ανάκτηση των splits: {str(e)}")
def create_iterator(dataset_name, configs, split):
"""Φορτώνει το dataset και αποδίδει τα κείμενα ως iterator."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης dataset για config {config}: {e}")
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων στο αρχείο checkpoint."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint αν υπάρχει."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def analyze_checkpoint(num_samples=1000):
"""
Διαβάζει τα πρώτα num_samples δείγματα από το checkpoint και επιστρέφει το ποσοστό γλωσσών.
"""
if not os.path.exists(CHECKPOINT_FILE):
return "Το αρχείο checkpoint δεν υπάρχει."
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
lines = f.read().splitlines()
sample_lines = lines[:num_samples] if len(lines) >= num_samples else lines
language_counts = {}
total = 0
for line in sample_lines:
try:
lang = detect(line)
language_counts[lang] = language_counts.get(lang, 0) + 1
total += 1
except Exception as e:
continue
if total == 0:
return "Δεν βρέθηκαν έγκυρα δείγματα για ανάλυση."
report = "Αποτελέσματα Ανάλυσης:\n"
for lang, count in language_counts.items():
report += f"Γλώσσα {lang}: {count/total*100:.2f}%\n"
return report
def collect_samples(dataset_name, configs, split):
"""
Ξεκινά τη συλλογή δειγμάτων από το dataset μέχρι να φτάσει το MAX_SAMPLES
ή μέχρι να ζητηθεί διακοπή (STOP_COLLECTION).
"""
global STOP_COLLECTION
STOP_COLLECTION = False # Βεβαιωνόμαστε ότι η συλλογή ξεκινάει ανενεργή τη διακοπή
total_processed = len(load_checkpoint())
progress_messages = [f"📌 Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint."]
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
for text in dataset_iterator:
if STOP_COLLECTION:
progress_messages.append("⏹️ Η συλλογή διακόπηκε από το χρήστη.")
break
new_texts.append(text)
total_processed += 1
if len(new_texts) >= CHUNK_SIZE:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.")
new_texts = []
if total_processed >= MAX_SAMPLES:
progress_messages.append("⚠️ Έφτασε το όριο δειγμάτων.")
break
if new_texts:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).")
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""
Εκπαιδεύει τον tokenizer χρησιμοποιώντας τα δείγματα που έχουν συλλεχθεί στο checkpoint.
"""
print("🚀 Ξεκινά η εκπαίδευση του tokenizer με τα δεδομένα του checkpoint...")
all_texts = load_checkpoint()
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR)
# Φόρτωση εκπαιδευμένου tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
# Δοκιμή
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Γράφημα κατανομής tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return (f"✅ Εκπαίδευση ολοκληρώθηκε!\nΑποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}",
decoded,
img_buffer.getvalue())
# Callbacks κουμπιών
def start_collection(dataset_name, configs, split):
"""Ξεκινά τη συλλογή δειγμάτων (ή επανεκκινεί τη συλλογή αν έχει γίνει restart)."""
msg = collect_samples(dataset_name, configs, split)
return msg
def stop_collection():
"""Θέτει το flag για διακοπή της συλλογής δειγμάτων."""
global STOP_COLLECTION
STOP_COLLECTION = True
return "Η συλλογή σταμάτησε από το χρήστη."
def restart_collection():
"""
Επαναφέρει τη συλλογή διαγράφοντας το checkpoint και
επαναφέροντας το flag ώστε να ξεκινήσει νέα συλλογή.
"""
global STOP_COLLECTION
STOP_COLLECTION = False
if os.path.exists(CHECKPOINT_FILE):
os.remove(CHECKPOINT_FILE)
return "Το checkpoint διαγράφτηκε. Μπορείς να ξεκινήσεις νέα συλλογή."
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Collection, Analysis & Training")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
split = gr.Dropdown(choices=["train"], value="train", label="Split")
vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
start_btn = gr.Button("Start Collection")
stop_btn = gr.Button("Stop Collection")
analyze_btn = gr.Button("Analyze Samples")
restart_btn = gr.Button("Restart Collection")
train_btn = gr.Button("Train Tokenizer")
with gr.Column():
progress = gr.Textbox(label="Progress", interactive=False, lines=10)
results_text = gr.Textbox(label="Test Decoded Text", interactive=False)
results_plot = gr.Image(label="Token Length Distribution")
# Έλεγχος ύπαρξης του tokenizer για download
initial_file_value = TOKENIZER_FILE if os.path.exists(TOKENIZER_FILE) else None
download_button = gr.File(label="Download Tokenizer", value=initial_file_value)
# Συνδέουμε τα κουμπιά με τις συναρτήσεις
start_btn.click(fn=start_collection,
inputs=[dataset_name, configs, split],
outputs=progress)
stop_btn.click(fn=stop_collection,
inputs=[],
outputs=progress)
analyze_btn.click(fn=lambda: analyze_checkpoint(1000),
inputs=[],
outputs=progress)
restart_btn.click(fn=restart_collection,
inputs=[],
outputs=progress)
train_btn.click(fn=train_tokenizer_fn,
inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text],
outputs=[progress, results_text, results_plot])
demo.launch()