safinal committed
Commit 3dad86d · verified · 1 Parent(s): 3c06693

Create token_classifier.py

Files changed (1):
  1. token_classifier.py +118 -0
token_classifier.py ADDED
@@ -0,0 +1,118 @@
import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification


# Define label mappings (ensure this matches the mappings used during training)
label2id = {'<negative_object>': 0, 'other': 2, '<positive_object>': 1}
id2label = {v: k for k, v in label2id.items()}


def prepare_input(tokens, tokenizer, max_length=128):
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True
    )
    return encoding

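# Note: with a fast tokenizer, encoding.word_ids() maps each sub-word piece back to
# the index of the word it came from (None for [CLS], [SEP], and padding positions);
# predict() below relies on this mapping to recover word-level labels.
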
def split_sentence(sentence):
    # List of special tokens to preserve
    special_tokens = ['<positive_object>', '<negative_object>']

    # More comprehensive list of punctuation marks and symbols
    punctuation = ',.?!;:()[]{}""\'`@#$%^&*+=|\\/<>-—–'

    # Initialize result list and temporary word
    result = []
    current_word = ''
    i = 0

    while i < len(sentence):
        # Check for special tokens
        found_special = False
        for token in special_tokens:
            if sentence[i:].startswith(token):
                # Add previous word if exists
                if current_word:
                    result.append(current_word)
                    current_word = ''
                # Add special token
                result.append(token)
                i += len(token)
                found_special = True
                break

        if found_special:
            continue

        # Handle punctuation
        if sentence[i] in punctuation:
            # Add previous word if exists
            if current_word:
                result.append(current_word)
                current_word = ''
            # Add punctuation as separate token
            result.append(sentence[i])

        # Handle spaces
        elif sentence[i].isspace():
            if current_word:
                result.append(current_word)
                current_word = ''

        # Build regular words
        else:
            current_word += sentence[i]

        i += 1

    # Add final word if exists
    if current_word:
        result.append(current_word)

    return result

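# Example (illustrative):
# split_sentence("grab the <positive_object>, avoid the <negative_object>!")
# -> ['grab', 'the', '<positive_object>', ',', 'avoid', 'the', '<negative_object>', '!']
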
def predict(tokens, model, tokenizer, device, max_length=128):
    tokens = split_sentence(' '.join(tokens.lower().split()))

    # Prepare the input
    encoding = prepare_input(tokens, tokenizer, max_length=max_length)
    word_ids = encoding.word_ids(batch_index=0)  # List of word IDs

    # Move tensors to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).cpu().numpy()[0]

    # Decode tokens and labels
    tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids.cpu().numpy()[0])
    labels = [id2label.get(pred, 'O') for pred in predictions]

    # Align tokens with original word-level tokens
    aligned_predictions = []
    previous_word_idx = None
    for token, label, word_idx in zip(tokens_decoded, labels, word_ids):
        if word_idx is None:
            continue
        if word_idx != previous_word_idx:
            aligned_predictions.append((tokens[word_idx], label))
        previous_word_idx = word_idx
    return aligned_predictions

def load_token_classifier(pretrained_token_classifier_path, device):
    # Load tokenizer and model
    tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_token_classifier_path)
    token_classifier = DistilBertForTokenClassification.from_pretrained(pretrained_token_classifier_path)
    token_classifier.to(device)
    return token_classifier, tokenizer
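
A minimal usage sketch (illustrative; the checkpoint path and example sentence are placeholders, not part of this commit):

import torch
from token_classifier import load_token_classifier, predict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder path: point this at the directory or Hub repo holding the trained checkpoint
token_classifier, tokenizer = load_token_classifier("path/to/token_classifier_checkpoint", device)

sentence = "grab the <positive_object>, avoid the <negative_object>!"
for word, label in predict(sentence, token_classifier, tokenizer, device):
    print(word, label)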