Terry Zhang committed on
Commit 3b83e0c · 1 Parent(s): 243d40e

update code to include tree classifier

tasks/text.py CHANGED
@@ -3,8 +3,8 @@ from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
 import random
-from skops.hub_utils import download
 from skops.io import load
+from .utils.text_preprocessor import TextPreprocessor


 from .utils.evaluation import TextEvaluationRequest
@@ -15,12 +15,40 @@ router = APIRouter()
 DESCRIPTION = "Random Baseline"
 ROUTE = "/text"

-MODEL_PATH = "tasks/text_models/xgb_pipeline.skops"
+models_description = {
+    "baseline": "random baseline",
+    "tfidf_xgb": "TF-IDF vectorizer and XGBoost classifier",
+}
+
+# Some code borrowed from Nonnormalizable
+
+def baseline_model(dataset_length: int):
+    # Make random predictions (placeholder for actual model inference)
+    predictions = [random.randint(0, 7) for _ in range(dataset_length)]
+
+    return predictions
+
+def tree_classifier(test_dataset: dict, model: str):
+    texts = test_dataset["quote"]
+
+    model_path = f"models/frugalai_{model}"
+
+    model = load(model_path,
+                 trusted=[
+                     '__main__.TextPreprocessor',
+                     'nltk.stem.wordnet.WordNetLemmatizer',
+                     'xgboost.core.Booster',
+                     'xgboost.sklearn.XGBClassifier'])
+
+    predictions = model.predict(texts)
+
+    return predictions


 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
-async def evaluate_text(request: TextEvaluationRequest):
+async def evaluate_text(request: TextEvaluationRequest,
+                        model: str = "baseline"):
     """
     Evaluate text classification for climate disinformation detection.

@@ -65,8 +93,10 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Make random predictions (placeholder for actual model inference)
     true_labels = test_dataset["label"]

-    model = load(MODEL_PATH)
-    predictions = model.predict(test_dataset["text"])
+    if model == "baseline":
+        predictions = baseline_model(len(true_labels))
+    elif model == "tfidf_xgb":
+        predictions = tree_classifier(test_dataset, model='tfidf_xgb')

     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
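The updated endpoint now selects its predictor through a model query parameter: "baseline" keeps the random predictions, while "tfidf_xgb" dispatches to tree_classifier(), which loads the skops pipeline from models/frugalai_tfidf_xgb. A minimal client sketch follows; the host/port and the request-body fields are assumptions based on the challenge template's TextEvaluationRequest, not something this commit defines.

    import requests

    # Hypothetical call against a locally served app; URL and body fields are assumed
    response = requests.post(
        "http://localhost:8000/text",
        params={"model": "tfidf_xgb"},  # new query parameter; defaults to "baseline"
        json={"test_size": 0.2, "test_seed": 42},  # assumed TextEvaluationRequest fields
    )
    print(response.json())
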
tasks/text_models/.gitattributes DELETED
@@ -1 +0,0 @@
-xgb_pipeline.skops filter=lfs diff=lfs merge=lfs -text

tasks/text_models/xgb_pipeline.skops DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6c2100f08f614713cd3e19f06e3456f32ef3d3bb23ce4ff2902688c8074bb82e
-size 3277312

tasks/utils/text_preprocessor.py ADDED
@@ -0,0 +1,29 @@
+import nltk
+from nltk.stem import WordNetLemmatizer
+from sklearn.base import BaseEstimator, TransformerMixin
+import contractions
+
+# Download required NLTK resources
+nltk.download('punkt_tab')
+nltk.download('wordnet')
+
+# Custom transformer for preprocessing text
+class TextPreprocessor(BaseEstimator, TransformerMixin):
+    def __init__(self):
+        self.lemmatizer = WordNetLemmatizer()
+
+    def fit(self, X, y=None):
+        return self  # Does nothing, just returns the instance
+
+    def transform(self, X):
+        preprocessed_texts = []
+        for doc in X:
+            # Expand contractions
+            expanded = contractions.fix(doc)
+            # Lowercase
+            lowered = expanded.lower()
+
+            # Tokenize and lemmatize
+            lemmatized = " ".join([self.lemmatizer.lemmatize(word) for word in nltk.word_tokenize(lowered)])
+            preprocessed_texts.append(lemmatized)
+        return preprocessed_texts
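tree_classifier() in tasks/text.py expects a skops artifact whose object graph contains exactly the trusted types listed there (TextPreprocessor, WordNetLemmatizer, and the XGBoost classes). Below is a minimal sketch of how such an artifact might be produced; the pipeline layout, toy data, and output path are assumptions, since the training script is not part of this commit. The '__main__.TextPreprocessor' entry in the trusted list suggests the transformer was defined at the top level of the training script rather than imported from tasks.utils.text_preprocessor.

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.pipeline import Pipeline
    from skops.io import dump
    from xgboost import XGBClassifier

    from tasks.utils.text_preprocessor import TextPreprocessor

    # Assumed layout matching the "TF-IDF vectorizer and XGBoost classifier" description
    pipeline = Pipeline([
        ("preprocess", TextPreprocessor()),  # contraction expansion + lemmatization
        ("tfidf", TfidfVectorizer()),
        ("clf", XGBClassifier()),
    ])

    # Toy stand-in data; the real task uses eight claim categories (labels 0-7)
    texts = ["The climate isn't changing.", "Global temperatures are rising."]
    labels = [1, 0]
    pipeline.fit(texts, labels)

    # Persist with skops; tree_classifier() then loads it with the matching trusted types
    dump(pipeline, "models/frugalai_tfidf_xgb")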