"""Text-classification evaluation route for climate disinformation detection.

Defines ``TextClassifier``, a thin wrapper around a Hugging Face sequence
classification model, and the ``/text`` FastAPI route that runs it over the
``test`` split of a dataset and reports accuracy together with emissions data.
"""
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict, Tuple

from datasets import load_dataset
from fastapi import APIRouter
from sklearn.metrics import accuracy_score
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking

# Disable torch compile
os.environ["TORCH_COMPILE_DISABLE"] = "1"

router = APIRouter()

DESCRIPTION = "Climate Guard Toxic Agent Classifier"
ROUTE = "/text"


class TextClassifier:
    """Load the classification model once and expose prediction helpers.

    Attributes set by ``__init__``:
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        tokenizer: tokenizer loaded for the model checkpoint.
        model: the sequence-classification model, moved to ``device`` and
            switched to eval mode.
    """

    # Token limit used both for the tokenizer configuration and for every
    # encode call, so the two can never drift apart.
    MAX_LENGTH = 512

    def __init__(self):
        """Load tokenizer and model, retrying up to three times.

        Raises:
            Exception: if every attempt fails; the last underlying error is
                attached as ``__cause__``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        max_retries = 3
        model_name = "Tonic/climate-guard-toxic-agent"

        for attempt in range(max_retries):
            try:
                # Load config first
                config = AutoConfig.from_pretrained(model_name)

                # Initialize tokenizer with explicit padding/truncation sides
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    model_max_length=self.MAX_LENGTH,
                    padding_side='right',
                    truncation_side='right'
                )

                # Initialize model with config
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    model_name,
                    config=config,
                    torch_dtype=torch.float32
                )

                self.model.to(self.device)
                self.model.eval()

                print("Model initialized successfully")
                break

            except Exception as e:
                if attempt == max_retries - 1:
                    # Chain the cause so the original traceback is preserved.
                    raise Exception(
                        f"Failed to initialize model after {max_retries} attempts: {str(e)}"
                    ) from e
                print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
                time.sleep(1)

    def _encode(self, texts):
        """Tokenize a string or a list of strings and move tensors to ``self.device``.

        Every input is truncated and padded to ``MAX_LENGTH`` tokens.
        """
        return self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
            padding='max_length'
        ).to(self.device)

    def predict_single(self, text: str) -> int:
        """Predict the class index for a single text.

        Returns:
            The argmax class index, or ``0`` if any error occurs (the error
            is printed, not raised).
        """
        try:
            inputs = self._encode(text)

            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.argmax(outputs.logits, dim=-1)

            return predictions.item()

        except Exception as e:
            print(f"Error in single prediction: {str(e)}")
            return 0  # Return default prediction on error

    def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
        """Predict class indices for a batch of texts.

        Args:
            batch: texts to classify in one forward pass.
            batch_idx: identifier of the batch; used in log output and
                returned unchanged.

        Returns:
            ``(predictions, batch_idx)``. If the batch fails, the predictions
            are ``0`` for every item (deliberate best-effort fallback; the
            error is printed, not raised).
        """
        try:
            print(f"Processing batch {batch_idx} with {len(batch)} items")

            # Process entire batch at once
            inputs = self._encode(batch)

            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.argmax(outputs.logits, dim=-1).tolist()

            print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
            return predictions, batch_idx

        except Exception as e:
            print(f"Error in batch {batch_idx}: {str(e)}")
            return [0] * len(batch), batch_idx


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """Evaluate text classification for climate disinformation detection.

    Loads the dataset named by ``request.dataset_name``, converts its string
    labels to integer ids, classifies the ``quote`` column of the ``test``
    split in batches of 32, and returns accuracy plus emissions data.

    ``request.test_size`` and ``request.test_seed`` are only echoed back in
    the result; this function does not use them to split the data.

    NOTE(review): the body contains no ``await``; inference runs synchronously
    inside this coroutine -- confirm that blocking the event loop is acceptable.
    """
    # Get space info
    username, space_url = get_space_info()

    # Define the label mapping
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
    test_dataset = dataset["test"]

    # Start tracking emissions
    start_tracking()
    try:
        true_labels = test_dataset["label"]

        # Initialize the model once
        classifier = TextClassifier()

        # Prepare batches
        batch_size = 32  # Increased batch size for efficiency
        quotes = test_dataset["quote"]
        batches = [
            quotes[start:start + batch_size]
            for start in range(0, len(quotes), batch_size)
        ]
        num_batches = len(batches)

        # Process batches sequentially to avoid memory issues
        predictions = []
        for idx, batch in enumerate(batches):
            batch_preds, _ = classifier.process_batch(batch, idx)
            predictions.extend(batch_preds)
            print(f"Processed batch {idx + 1}/{num_batches}")
    finally:
        # Always stop tracking, even if model loading or inference raised, so
        # a failed request does not skip the stop call.
        # NOTE(review): assumes stop_tracking() is safe to call after a
        # failed run -- confirm against utils.emissions.
        emissions_data = stop_tracking()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)
    print("accuracy:", accuracy)

    # Prepare results
    # NOTE(review): the * 1000 factors presumably convert kWh -> Wh and
    # kg -> g; verify against the tracker's units.
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }
    print("results:", results)
    return results