GPT2-PBE / app.py
tymbos's picture
Update app.py
c259678 verified
raw
history blame
9.65 kB
# -*- coding: utf-8 -*-
import os
import gc
import gradio as gr
import requests
import time
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = os.getcwd() # Χρησιμοποιεί τον τρέχοντα φάκελο
#TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000 # Αυξημένο όριο δειγμάτων
DEFAULT_CHUNK_SIZE = 200000 # Μεγαλύτερο chunk size
BATCH_SIZE = 1000 # Μέγεθος batch για φόρτωση δεδομένων
NUM_WORKERS = 4 # Αριθμός workers για πολυνηματική επεξεργασία
# Παγκόσμια μεταβλητή ελέγχου
STOP_COLLECTION = False
# Καταγραφή εκκίνησης
startup_log = f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n"
print(startup_log)
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων με ομαδοποίηση."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
batch = "\n".join(texts) + "\n"
f.write(batch)
def create_iterator(dataset_name, configs, split):
"""Βελτιωμένο iterator με batch φόρτωση και caching."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(
dataset_name,
name=config,
split=split,
streaming=True,
cache_dir="./dataset_cache" # Ενεργοποίηση cache
)
# Φόρτωση δεδομένων σε batches
while True:
batch = list(dataset.take(BATCH_SIZE))
if not batch:
break
dataset = dataset.skip(BATCH_SIZE)
# Πολυνηματική επεξεργασία batch
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
processed_texts = list(executor.map(process_example, batch))
yield from filter(None, processed_texts)
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης: {config}: {e}")
def process_example(example):
"""Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
try:
text = example.get('text', '').strip()
if text and detect(text) in ['el', 'en']: # Φιλτράρισμα γλώσσας
return text
return None
except:
return None
def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
"""Βελτιωμένη συλλογή δεδομένων με μεγάλα chunks."""
global STOP_COLLECTION
STOP_COLLECTION = False
total_processed = len(load_checkpoint())
progress_messages = [
f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}",
f"⚙️ Ρυθμίσεις: Chunk Size={chunk_size}, Workers={NUM_WORKERS}"
]
dataset_iterator = create_iterator(dataset_name, configs, split)
chunk = []
while not STOP_COLLECTION and total_processed < max_samples:
try:
# Φόρτωση chunk
while len(chunk) < chunk_size:
text = next(dataset_iterator)
if text:
chunk.append(text)
total_processed += 1
if total_processed >= max_samples:
break
# Αποθήκευση chunk
if chunk:
append_to_checkpoint(chunk)
progress_messages.append(
f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})"
)
chunk = []
# Εκκαθάριση μνήμης
gc.collect()
except StopIteration:
progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
break
except Exception as e:
progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
break
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Βελτιωμένη εκπαίδευση tokenizer με χρήση cache."""
print("🚀 Εκκίνηση εκπαίδευσης...")
all_texts = load_checkpoint()
# Παράλληλη επεξεργασία για εκπαίδευση
tokenizer = train_tokenizer(
all_texts,
vocab_size=vocab_size,
min_frequency=min_freq,
output_dir=TOKENIZER_DIR,
num_threads=NUM_WORKERS # Παράλληλη επεξεργασία
)
# Φόρτωση και δοκιμή tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Δημιουργία γραφήματος
fig, ax = plt.subplots()
ax.hist([len(t) for t in encoded.tokens], bins=20)
ax.set_xlabel('Μήκος Token')
ax.set_ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return ("✅ Εκπαίδευση ολοκληρώθηκε!", decoded, Image.open(img_buffer))
print(f"Ο tokenizer αποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}")
def analyze_checkpoint():
"""Νέα λειτουργία ανάλυσης δεδομένων."""
texts = load_checkpoint()
if not texts:
return "Δεν βρέθηκαν δεδομένα για ανάλυση."
# Βασική στατιστική
total_chars = sum(len(t) for t in texts)
avg_length = total_chars / len(texts) if texts else 0
# Ανάλυση γλώσσας
languages = {}
for t in texts[:1000]: # Δειγματοληψία για ταχύτητα
try:
lang = detect(t)
languages[lang] = languages.get(lang, 0) + 1
except:
continue
report = [
f"📊 Σύνολο δειγμάτων: {len(texts)}",
f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
"🌍 Γλώσσες (δείγμα 1000):",
*[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
]
return "\n".join(report)
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Βελτιωμένος Wikipedia Tokenizer Trainer")
with gr.Row():
with gr.Column(scale=2):
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
split = gr.Dropdown(["train"], value="train", label="Split")
chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
vocab_size = gr.Slider(20000, 200000, value=50000, step=10000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Maximum Samples")
with gr.Row():
start_btn = gr.Button("Start", variant="primary")
stop_btn = gr.Button("Stop", variant="stop")
restart_btn = gr.Button("Restart")
analyze_btn = gr.Button("Analyze Data")
train_btn = gr.Button("Train Tokenizer", variant="primary")
with gr.Column(scale=3):
progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
gr.Markdown("### Αποτελέσματα")
decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο")
token_distribution = gr.Image(label="Κατανομή Tokens")
# Event handlers
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
stop_btn.click(lambda: "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
restart_btn.click(lambda: "🔄 Επαναφορά...", None, progress).then(restart_collection, None, progress)
analyze_btn.click(analyze_checkpoint, None, progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
demo.queue(concurrency_count=4).launch()