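"""Gradio app that builds a text dataset from URLs, uploaded files, and pasted
text, saves it as combined_dataset.json, fine-tunes a Hugging Face
sequence-classification model on it, and writes a small deployment script for
inference."""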
import json
import logging
import mimetypes
import os
import tempfile
import time
import zipfile

import gradio as gr
import requests
import torch
from bs4 import BeautifulSoup
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --- URL and File Processing Functions ---
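# Download a page with a timeout and a small retry loop; returns the HTML text,
# or None once the final attempt fails.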
def fetch_content(url, retries=3):
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            logger.info(f"Successfully fetched content from {url}")
            return response.text
        except requests.RequestException as e:
            logger.error(f"Error fetching {url} (attempt {attempt + 1}/{retries}): {e}")
            if attempt == retries - 1:
                return None

def extract_text(html):
    if not html:
        logger.warning("Empty HTML content provided for extraction.")
        return ""
    soup = BeautifulSoup(html, 'html.parser')
    for script in soup(["script", "style"]):
        script.decompose()
    text = soup.get_text()
    lines = (line.strip() for line in text.splitlines())
    # Split on double spaces so phrases stay together instead of one word per line
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    extracted_text = '\n'.join(chunk for chunk in chunks if chunk)
    logger.info("Text extraction completed.")
    return extracted_text
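
# Fetch and clean each URL in turn, skipping malformed entries and sleeping
# briefly between requests to avoid hammering the target servers.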
def process_urls(urls):
    dataset = []
    for url in tqdm(urls, desc="Fetching URLs"):
        if not url.startswith("http://") and not url.startswith("https://"):
            logger.warning(f"Invalid URL format: {url}")
            continue
        html = fetch_content(url)
        if html:
            text = extract_text(html)
            if text:
                dataset.append({"source": "url", "url": url, "content": text})
            else:
                logger.warning(f"No text extracted from {url}")
        else:
            logger.error(f"Failed to fetch content from {url}")
        time.sleep(1)
    return dataset
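
# The file input below is configured with type="filepath", so Gradio hands this
# function a plain string path to the uploaded file (which may be a zip archive).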
def process_file(file_path):
    dataset = []
    with tempfile.TemporaryDirectory() as temp_dir:
        if zipfile.is_zipfile(file_path):
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            for root, _, files in os.walk(temp_dir):
                for filename in files:
                    filepath = os.path.join(root, filename)
                    mime_type, _ = mimetypes.guess_type(filepath)
                    if mime_type and mime_type.startswith('text'):
                        with open(filepath, 'r', errors='ignore') as f:
                            content = f.read()
                        if content.strip():
                            dataset.append({"source": "file", "filename": filename, "content": content})
                        else:
                            logger.warning(f"File {filename} is empty.")
                    else:
                        logger.warning(f"File {filename} is not a text file.")
                        dataset.append({"source": "file", "filename": filename, "content": "Binary file - content not extracted"})
        else:
            mime_type, _ = mimetypes.guess_type(file_path)
            if mime_type and mime_type.startswith('text'):
                with open(file_path, 'r', errors='ignore') as f:
                    content = f.read()
                if content.strip():
                    dataset.append({"source": "file", "filename": os.path.basename(file_path), "content": content})
                else:
                    logger.warning(f"Uploaded file {file_path} is empty.")
            else:
                logger.warning(f"Uploaded file {file_path} is not a text file.")
                dataset.append({"source": "file", "filename": os.path.basename(file_path), "content": "Binary file - content not extracted"})
    return dataset

def create_dataset(urls, file, text_input):
    dataset = []
    if urls:
        dataset.extend(process_urls([url.strip() for url in urls.split(',') if url.strip()]))
    if file:
        dataset.extend(process_file(file))
    if text_input:
        dataset.append({"source": "input", "content": text_input})
    logger.info(f"Dataset created with {len(dataset)} entries.")
    output_file = 'combined_dataset.json'
    with open(output_file, 'w') as f:
        json.dump(dataset, f, indent=2)
    return output_file


# --- Model Training and Evaluation Functions ---
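# Wraps the collected records for the Trainer. Entries have no "label" field by
# default, so every example falls back to label 0; add real labels to the JSON
# records if you want a meaningful classifier.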
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            text = self.data[idx]['content']
            label = self.data[idx].get('label', 0)  # default label when none is provided
            encoding = self.tokenizer.encode_plus(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )
            return {
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'labels': torch.tensor(label, dtype=torch.long)
            }
        except Exception as e:
            logger.error(f"Error in processing item {idx}: {e}")
            raise
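
# Fine-tune a sequence-classification model on an 80/20 train/validation split.
# max_length defaults to 512 because BERT-style encoders such as the default
# distilbert-base-uncased cannot take longer inputs.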
def train_model(model_name, data, batch_size, epochs, learning_rate=1e-5, max_length=512):
    try:
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)

        dataset = CustomDataset(data, tokenizer, max_length=max_length)
        if len(dataset) == 0:
            logger.error("The dataset is empty. Please check the input data.")
            return None, None

        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            eval_strategy='epoch',
            save_strategy='epoch',
            learning_rate=learning_rate,
            load_best_model_at_end=True,
            metric_for_best_model='accuracy',
            greater_is_better=True,
            save_total_limit=2,
            seed=42,
            dataloader_num_workers=4,
            fp16=torch.cuda.is_available()
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=lambda pred: {
                'accuracy': accuracy_score(pred.label_ids, pred.predictions.argmax(-1))
            }
        )

        logger.info("Starting model training...")
        start_time = time.time()
        trainer.train()
        end_time = time.time()
        logger.info(f'Training time: {end_time - start_time:.2f} seconds')

        logger.info("Evaluating model...")
        eval_result = trainer.evaluate()
        logger.info(f'Evaluation result: {eval_result}')

        trainer.save_model('./model')
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error during training: {e}")
        raise
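
# Save the fine-tuned model and tokenizer and write a standalone deployment.py
# that reloads them and exposes a predict(text) helper.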
def deploy_model(model, tokenizer):
    try:
        model.save_pretrained('./model')
        tokenizer.save_pretrained('./model')

        deployment_script = '''import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForSequenceClassification.from_pretrained('./model').to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained('./model')

def predict(text):
    encoding = tokenizer.encode_plus(
        text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    return torch.argmax(logits, dim=1).cpu().numpy()[0]
'''
        with open('./deployment.py', 'w') as f:
            f.write(deployment_script)
        logger.info('Model deployed successfully. Import predict from deployment.py to run inference.')
    except Exception as e:
        logger.error(f"Error deploying model: {e}")
        raise


# Gradio Interface
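# Gradio callback: builds the dataset JSON, trains and deploys the model, and
# returns the JSON path for download; errors are returned as strings.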
def gradio_interface(urls, file, text_input, model_name, batch_size, epochs):
    try:
        dataset_file = create_dataset(urls, file, text_input)
        with open(dataset_file, 'r') as f:
            dataset = json.load(f)
        if not dataset:
            return "Error: The dataset is empty. Please check your inputs."
        model, tokenizer = train_model(model_name, dataset, batch_size, epochs)
        deploy_model(model, tokenizer)
        return dataset_file
    except Exception as e:
        logger.error(f"Error in gradio_interface: {e}")
        return f"An error occurred: {str(e)}"

# Gradio Interface Setup
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=5, label="Enter comma-separated URLs", placeholder="http://example.com, https://example.org"),
        gr.File(label="Upload file (including zip files)", type="filepath"),
        gr.Textbox(lines=10, label="Enter or paste large text", placeholder="Your text here..."),
        gr.Textbox(label="Model name", value="distilbert-base-uncased"),
        gr.Number(label="Batch size", value=8, precision=0, step=1),
        gr.Number(label="Epochs", value=3, precision=0, step=1),
    ],
    outputs=gr.File(label="Download Combined Dataset"),
    title="Dataset Creation and Model Training",
    description="Enter URLs, upload files (including zip files), and/or paste text to create a dataset and train a model.",
    theme="default",
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()