Zen0 commited on
Commit
c5c3dc9
·
verified ·
1 Parent(s): 5e01226

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +7 -6
tasks/text.py CHANGED
@@ -8,9 +8,11 @@ from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
- import torch
12
 
13
  import numpy as np
 
 
 
14
 
15
 
16
  router = APIRouter()
@@ -61,14 +63,13 @@ async def evaluate_text(request: TextEvaluationRequest):
61
  #--------------------------------------------------------------------------------------------
62
 
63
 
64
- # Model and Tokenizer
65
  model_name = "Zen0/FrugalDisinfoHunter"
66
  tokenizer = AutoTokenizer.from_pretrained(model_name)
67
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
68
-
69
 
70
  # Tokenize the test data
71
- test_texts = test_dataset["text"] # Extracting the 'text' column (quotes)
72
  inputs = tokenizer(test_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
73
 
74
  # Move model and inputs to GPU if available
@@ -81,9 +82,9 @@ async def evaluate_text(request: TextEvaluationRequest):
81
  outputs = model(**inputs)
82
  logits = outputs.logits
83
 
84
- # Get predictions from the logits (choose the class with the highest logit)
85
  predictions = torch.argmax(logits, dim=-1).cpu().numpy()
86
-
87
  true_labels = test_dataset['label']
88
 
89
  #--------------------------------------------------------------------------------------------
 
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
11
 
12
  import numpy as np
13
+ from climate_model import ModelWrapper
14
+ from preprocessing import ClimateTextPreprocessor
15
+ import torch
16
 
17
 
18
  router = APIRouter()
 
63
  #--------------------------------------------------------------------------------------------
64
 
65
 
66
+ # Model and Tokenizer
67
  model_name = "Zen0/FrugalDisinfoHunter"
68
  tokenizer = AutoTokenizer.from_pretrained(model_name)
69
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
70
 
71
  # Tokenize the test data
72
+ test_texts = test_dataset["quote"] # Changed from "text" to "quote"
73
  inputs = tokenizer(test_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
74
 
75
  # Move model and inputs to GPU if available
 
82
  outputs = model(**inputs)
83
  logits = outputs.logits
84
 
85
+ # Get predictions from the logits
86
  predictions = torch.argmax(logits, dim=-1).cpu().numpy()
87
+
88
  true_labels = test_dataset['label']
89
 
90
  #--------------------------------------------------------------------------------------------