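"""AI text detector Space.

Classifies text as human- or AI-written with a fine-tuned DeBERTa-v3 sequence
classifier, exposes the analysis through a Gradio interface, and logs prediction
metrics (and optionally the submitted text) to monthly CSV files.
"""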
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import spacy
from typing import List, Dict, Tuple
import logging
import os
import gradio as gr
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import time
import csv
from datetime import datetime
import threading
import random

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_LENGTH = 512
MODEL_NAME = "microsoft/deberta-v3-small"
WINDOW_SIZE = 6
WINDOW_OVERLAP = 2
CONFIDENCE_THRESHOLD = 0.65
BATCH_SIZE = 8  # Reduced batch size for CPU
MAX_WORKERS = 4  # Number of worker threads for processing
class CSVLogger:
    def __init__(self, log_dir="."):
        """Initialize the CSV logger.

        Args:
            log_dir: Directory to store CSV log files
        """
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        # Create monthly CSV files
        current_month = datetime.now().strftime('%Y-%m')
        self.metrics_path = os.path.join(log_dir, f"metrics_{current_month}.csv")
        self.text_path = os.path.join(log_dir, f"text_data_{current_month}.csv")

        # Define headers
        self.metrics_headers = [
            'entry_id', 'timestamp', 'word_count', 'mode', 'prediction',
            'confidence', 'prediction_time_seconds', 'num_sentences'
        ]
        self.text_headers = ['entry_id', 'timestamp', 'text']

        # Initialize the files if they don't exist
        self._initialize_files()

        # Create locks for thread safety
        self.metrics_lock = threading.Lock()
        self.text_lock = threading.Lock()

        print(f"CSV logger initialized with files at: {os.path.abspath(self.metrics_path)}")
    def _initialize_files(self):
        """Create the CSV files with headers if they don't exist."""
        # Initialize metrics file
        if not os.path.exists(self.metrics_path):
            with open(self.metrics_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(self.metrics_headers)

        # Initialize text data file
        if not os.path.exists(self.text_path):
            with open(self.text_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(self.text_headers)
    def log_prediction(self, prediction_data, store_text=True):
        """Log prediction data to CSV files.

        Args:
            prediction_data: Dictionary containing prediction metrics
            store_text: Whether to store the full text
        """
        # Generate a unique entry ID
        entry_id = f"{datetime.now().strftime('%Y%m%d%H%M%S')}_{random.randint(1000, 9999)}"

        # Extract text if present
        text = prediction_data.pop('text', None) if store_text else None

        # Ensure timestamp is present
        if 'timestamp' not in prediction_data:
            prediction_data['timestamp'] = datetime.now().isoformat()

        # Add entry_id to metrics data
        metrics_data = prediction_data.copy()
        metrics_data['entry_id'] = entry_id

        # Start a thread to write data
        thread = threading.Thread(
            target=self._write_to_csv,
            args=(metrics_data, text, entry_id, store_text)
        )
        thread.daemon = True
        thread.start()
    def _write_to_csv(self, metrics_data, text, entry_id, store_text):
        """Write data to CSV files with retry mechanism."""
        max_retries = 5
        retry_delay = 0.5

        # Write metrics data
        for attempt in range(max_retries):
            try:
                with self.metrics_lock:
                    with open(self.metrics_path, 'a', newline='') as f:
                        writer = csv.writer(f)
                        # Prepare row in the correct order based on headers
                        row = [
                            metrics_data.get('entry_id', ''),
                            metrics_data.get('timestamp', ''),
                            metrics_data.get('word_count', 0),
                            metrics_data.get('mode', ''),
                            metrics_data.get('prediction', ''),
                            metrics_data.get('confidence', 0.0),
                            metrics_data.get('prediction_time_seconds', 0.0),
                            metrics_data.get('num_sentences', 0)
                        ]
                        writer.writerow(row)
                print(f"Successfully wrote metrics to CSV, entry_id: {entry_id}")
                break
            except Exception as e:
                print(f"Error writing metrics to CSV (attempt {attempt+1}/{max_retries}): {e}")
                time.sleep(retry_delay * (attempt + 1))
        else:
            # If all retries fail, write to backup file
            backup_path = os.path.join(self.log_dir, f"metrics_backup_{datetime.now().strftime('%Y%m%d%H%M%S')}.csv")
            try:
                with open(backup_path, 'w', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(self.metrics_headers)
                    row = [
                        metrics_data.get('entry_id', ''),
                        metrics_data.get('timestamp', ''),
                        metrics_data.get('word_count', 0),
                        metrics_data.get('mode', ''),
                        metrics_data.get('prediction', ''),
                        metrics_data.get('confidence', 0.0),
                        metrics_data.get('prediction_time_seconds', 0.0),
                        metrics_data.get('num_sentences', 0)
                    ]
                    writer.writerow(row)
                print(f"Wrote metrics backup to {backup_path}")
            except Exception as e:
                print(f"Error writing metrics backup: {e}")

        # Write text data if requested
        if store_text and text:
            for attempt in range(max_retries):
                try:
                    with self.text_lock:
                        with open(self.text_path, 'a', newline='') as f:
                            writer = csv.writer(f)
                            # Handle potential newlines in text by replacing them
                            safe_text = text.replace('\n', ' ').replace('\r', ' ') if text else ''
                            writer.writerow([entry_id, metrics_data.get('timestamp', ''), safe_text])
                    print(f"Successfully wrote text data to CSV, entry_id: {entry_id}")
                    break
                except Exception as e:
                    print(f"Error writing text data to CSV (attempt {attempt+1}/{max_retries}): {e}")
                    time.sleep(retry_delay * (attempt + 1))
            else:
                # If all retries fail, write to backup file
                backup_path = os.path.join(self.log_dir, f"text_backup_{datetime.now().strftime('%Y%m%d%H%M%S')}.csv")
                try:
                    with open(backup_path, 'w', newline='') as f:
                        writer = csv.writer(f)
                        writer.writerow(self.text_headers)
                        safe_text = text.replace('\n', ' ').replace('\r', ' ') if text else ''
                        writer.writerow([entry_id, metrics_data.get('timestamp', ''), safe_text])
                    print(f"Wrote text data backup to {backup_path}")
                except Exception as e:
                    print(f"Error writing text data backup: {e}")
class TextWindowProcessor:
    def __init__(self):
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info("Downloading spacy model...")
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')

        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

        # Initialize thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]:
        if len(sentences) < window_size:
            return [" ".join(sentences)]

        windows = []
        stride = window_size - overlap
        for i in range(0, len(sentences) - window_size + 1, stride):
            window = sentences[i:i + window_size]
            windows.append(" ".join(window))
        return windows

    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        """Create windows with better boundary handling."""
        windows = []
        window_sentence_indices = []

        for i in range(len(sentences)):
            # Calculate window boundaries centered on current sentence
            half_window = window_size // 2
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)

            # Create the window
            window = sentences[start_idx:end_idx]
            windows.append(" ".join(window))
            window_sentence_indices.append(list(range(start_idx, end_idx)))

        return windows, window_sentence_indices
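
# TextClassifier wraps the DeBERTa-v3 sequence classifier. quick_scan averages
# probabilities over overlapping windows for a single document-level verdict;
# detailed_scan scores every sentence from the windows centered on it and applies
# light smoothing at prediction boundaries.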
class TextClassifier:
    def __init__(self):
        # Set thread configuration before any model loading or parallel work
        if not torch.cuda.is_available():
            torch.set_num_threads(MAX_WORKERS)
            torch.set_num_interop_threads(MAX_WORKERS)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()

    def initialize_model(self):
        """Initialize the model and tokenizer."""
        logger.info("Initializing model and tokenizer...")

        from transformers import DebertaV2TokenizerFast

        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
            self.model_name,
            model_max_length=MAX_LENGTH,
            use_fast=True
        )

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)

        model_path = "model_20250209_184929_acc1.0000.pt"
        if os.path.exists(model_path):
            logger.info(f"Loading custom model from {model_path}")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")

        self.model.eval()
    def quick_scan(self, text: str) -> Dict:
        """Perform a quick scan using simple window analysis."""
        if not text.strip():
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        sentences = self.processor.split_into_sentences(text)
        windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP)

        predictions = []
        # Process windows in smaller batches for CPU efficiency
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]

            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            for idx, window in enumerate(batch_windows):
                prediction = {
                    'window': window,
                    'human_prob': probs[idx][1].item(),
                    'ai_prob': probs[idx][0].item(),
                    'prediction': 'human' if probs[idx][1] > probs[idx][0] else 'ai'
                }
                predictions.append(prediction)

            # Clean up GPU memory if available
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions)
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions)

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_windows': len(predictions)
        }
    def detailed_scan(self, text: str) -> Dict:
        """Perform a detailed scan with improved sentence-level analysis."""
        # Clean up trailing whitespace
        text = text.rstrip()

        if not text.strip():
            return {
                'sentence_predictions': [],
                'highlighted_text': '',
                'full_text': '',
                'overall_prediction': {
                    'prediction': 'unknown',
                    'confidence': 0.0,
                    'num_sentences': 0
                }
            }

        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            return {}

        # Create centered windows for each sentence
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)

        # Track scores for each sentence
        sentence_appearances = {i: 0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}

        # Process windows in batches
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            batch_indices = window_sentence_indices[i:i + BATCH_SIZE]

            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)
            # Attribute predictions with weighted scoring. Window (i + window_idx)
            # was built around sentence (i + window_idx), so locate that sentence's
            # position inside the window rather than assuming it sits in the middle
            # (it does not for windows near the start or end of the text).
            for window_idx, indices in enumerate(batch_indices):
                center_pos = indices.index(i + window_idx)
                center_weight = 0.7  # Higher weight for center sentence
                # Distribute remaining weight; guard single-sentence windows against division by zero
                edge_weight = 0.3 / (len(indices) - 1) if len(indices) > 1 else 0.0

                for pos, sent_idx in enumerate(indices):
                    # Apply higher weight to center sentence
                    weight = center_weight if pos == center_pos else edge_weight
                    sentence_appearances[sent_idx] += weight
                    sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()

            # Clean up memory
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # Calculate final predictions with boundary smoothing
        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]

                # Apply minimal smoothing at prediction boundaries
                if i > 0 and i < len(sentences) - 1:
                    prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
                    prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
                    next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
                    next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]

                    # Check if we're at a prediction boundary
                    current_pred = 'human' if human_prob > ai_prob else 'ai'
                    prev_pred = 'human' if prev_human > prev_ai else 'ai'
                    next_pred = 'human' if next_human > next_ai else 'ai'

                    if current_pred != prev_pred or current_pred != next_pred:
                        # Small adjustment at boundaries
                        smooth_factor = 0.1
                        human_prob = (human_prob * (1 - smooth_factor) +
                                      (prev_human + next_human) * smooth_factor / 2)
                        ai_prob = (ai_prob * (1 - smooth_factor) +
                                   (prev_ai + next_ai) * smooth_factor / 2)

                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })

        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }
    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        """Format predictions as HTML with color-coding."""
        html_parts = []

        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']

            if confidence >= CONFIDENCE_THRESHOLD:
                if pred['prediction'] == 'human':
                    color = "#90EE90"  # Light green
                else:
                    color = "#FFB6C6"  # Light red
            else:
                if pred['prediction'] == 'human':
                    color = "#E8F5E9"  # Very light green
                else:
                    color = "#FFEBEE"  # Very light red

            html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')

        return " ".join(html_parts)
    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        """Aggregate predictions from multiple sentences into a single prediction."""
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }

        total_human_prob = sum(p['human_prob'] for p in predictions)
        total_ai_prob = sum(p['ai_prob'] for p in predictions)
        num_sentences = len(predictions)

        avg_human_prob = total_human_prob / num_sentences
        avg_ai_prob = total_ai_prob / num_sentences

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }
# Initialize the logger
csv_logger = CSVLogger(log_dir=".")

# Helper for the file listing debug endpoint
def list_files():
    """List all files in the current directory and subdirectories."""
    all_files = []
    for root, dirs, files in os.walk('.'):
        for file in files:
            all_files.append(os.path.join(root, file))
    return all_files
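
# analyze_text drives the Gradio interface: it falls back to quick mode for short
# texts (< 200 words), formats the highlighted and plain-text outputs, and logs
# timing and prediction metrics to CSV in the background.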
def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
    """Analyze text using specified mode and return formatted results."""
    # Start timing the prediction
    start_time = time.time()

    # Count words in the text
    word_count = len(text.split())

    # If text is less than 200 words and detailed mode is selected, switch to quick mode
    original_mode = mode
    if word_count < 200 and mode == "detailed":
        mode = "quick"

    if mode == "quick":
        result = classifier.quick_scan(text)

        prediction = result['prediction']
        confidence = result['confidence']
        num_windows = result['num_windows']

        quick_analysis = f"""
        PREDICTION: {prediction.upper()}
        Confidence: {confidence*100:.1f}%
        Windows analyzed: {num_windows}
        """

        # Add note if mode was switched
        if original_mode == "detailed":
            quick_analysis += f"\n\nNote: Switched to quick mode because text contains only {word_count} words. Minimum 200 words required for detailed analysis."

        output = (
            text,  # No highlighting in quick mode
            "Quick scan mode - no sentence-level analysis available",
            quick_analysis
        )

        # End timing
        end_time = time.time()
        prediction_time = end_time - start_time

        # Log the data
        log_data = {
            "timestamp": datetime.now().isoformat(),
            "word_count": word_count,
            "mode": mode,
            "prediction": prediction,
            "confidence": confidence,
            "prediction_time_seconds": prediction_time,
            "num_sentences": 0,  # No sentence analysis in quick mode
            "text": text
        }

        # Log to CSV
        print(f"Logging prediction data: word_count={word_count}, mode={mode}, prediction={prediction}")
        csv_logger.log_prediction(log_data)
    else:
        analysis = classifier.detailed_scan(text)

        prediction = analysis['overall_prediction']['prediction']
        confidence = analysis['overall_prediction']['confidence']
        num_sentences = analysis['overall_prediction']['num_sentences']

        detailed_analysis = []
        for pred in analysis['sentence_predictions']:
            pred_confidence = pred['confidence'] * 100
            detailed_analysis.append(f"Sentence: {pred['sentence']}")
            detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
            detailed_analysis.append(f"Confidence: {pred_confidence:.1f}%")
            detailed_analysis.append("-" * 50)

        final_pred = analysis['overall_prediction']
        overall_result = f"""
        FINAL PREDICTION: {final_pred['prediction'].upper()}
        Overall confidence: {final_pred['confidence']*100:.1f}%
        Number of sentences analyzed: {final_pred['num_sentences']}
        """

        output = (
            analysis['highlighted_text'],
            "\n".join(detailed_analysis),
            overall_result
        )

        # End timing
        end_time = time.time()
        prediction_time = end_time - start_time

        # Log the data
        log_data = {
            "timestamp": datetime.now().isoformat(),
            "word_count": word_count,
            "mode": mode,
            "prediction": prediction,
            "confidence": confidence,
            "prediction_time_seconds": prediction_time,
            "num_sentences": num_sentences,
            "text": text
        }

        # Log to CSV
        print(f"Logging prediction data: word_count={word_count}, mode={mode}, prediction={prediction}")
        csv_logger.log_prediction(log_data)

    return output
# Initialize the classifier globally
classifier = TextClassifier()

# Create Gradio interface
demo = gr.Interface(
    fn=lambda text, mode: analyze_text(text, mode, classifier),
    inputs=[
        gr.Textbox(
            lines=8,
            placeholder="Enter text to analyze...",
            label="Input Text"
        ),
        gr.Radio(
            choices=["quick", "detailed"],
            value="quick",
            label="Analysis Mode",
            info="Quick mode for faster analysis, Detailed mode for sentence-level analysis"
        )
    ],
    outputs=[
        gr.HTML(label="Highlighted Analysis"),
        gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
        gr.Textbox(label="Overall Result", lines=4)
    ],
    title="AI Text Detector",
    description="Analyze text to detect if it was written by a human or AI. Choose between quick scan and detailed sentence-level analysis. 200+ words suggested for accurate predictions. Note: For testing purposes, text and analysis data will be recorded.",
    api_name="predict",
    flagging_mode="never"
)
app = demo.app

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],  # Explicitly list methods
    allow_headers=["*"],
)

# Add file listing endpoint for debugging, registered on the underlying FastAPI app
# (the "/files" path is an arbitrary choice for this debug route)
@app.get("/files")
async def get_files():
    return {"files": list_files()}
# Ensure CORS is applied before launching
if __name__ == "__main__":
    # The CSVLogger created this month's CSV files at import time; report their status
    current_month = datetime.now().strftime('%Y-%m')
    metrics_path = f"metrics_{current_month}.csv"
    text_path = f"text_data_{current_month}.csv"

    print(f"Current directory: {os.getcwd()}")
    print(f"Looking for CSV files: {metrics_path}, {text_path}")

    if not os.path.exists(metrics_path):
        print(f"Warning: metrics CSV file not found: {metrics_path}")
    if not os.path.exists(text_path):
        print(f"Warning: text data CSV file not found: {text_path}")

    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )