import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification


# Define label mappings (ensure this matches the mappings used during training)
label2id = {'<negative_object>': 0, 'other': 2, '<positive_object>': 1}
id2label = {v: k for k, v in label2id.items()}

def prepare_input(tokens, tokenizer, max_length=128):
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True
    )
    return encoding
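
# Note: with a fast tokenizer, the returned BatchEncoding exposes
# encoding.word_ids(batch_index=0), which maps each sub-token position back to
# the index of the original word (None for special and padding tokens).
# predict() below relies on this mapping for word-level label alignment; the
# offset_mapping requested above is returned but not otherwise used here.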


def split_sentence(sentence):
    # List of special tokens to preserve
    special_tokens = ['<positive_object>', '<negative_object>']
    
    # Punctuation marks and symbols that should be split off as separate tokens
    punctuation = ',.?!;:()[]{}""\'`@#$%^&*+=|\\/<>-—–'
    
    # Initialize result list and temporary word
    result = []
    current_word = ''
    i = 0
    
    while i < len(sentence):
        # Check for special tokens
        found_special = False
        for token in special_tokens:
            if sentence[i:].startswith(token):
                # Add previous word if exists
                if current_word:
                    result.append(current_word)
                    current_word = ''
                # Add special token
                result.append(token)
                i += len(token)
                found_special = True
                break
        
        if found_special:
            continue
            
        # Handle punctuation
        if sentence[i] in punctuation:
            # Add previous word if exists
            if current_word:
                result.append(current_word)
                current_word = ''
            # Add punctuation as separate token
            result.append(sentence[i])
            
        # Handle spaces
        elif sentence[i].isspace():
            if current_word:
                result.append(current_word)
                current_word = ''
                
        # Build regular words
        else:
            current_word += sentence[i]
            
        i += 1
    
    # Add final word if exists
    if current_word:
        result.append(current_word)
        
    return result
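
# Illustrative example (hypothetical input, not from the original file):
#   split_sentence('<positive_object> is on the table.')
#   -> ['<positive_object>', 'is', 'on', 'the', 'table', '.']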

def predict(tokens, model, tokenizer, device, max_length=128):
    # `tokens` is a raw sentence string: normalize whitespace, lowercase it,
    # and split it into word-level tokens (special object tokens preserved).
    tokens = split_sentence(' '.join(tokens.lower().split()))

    # Prepare the input
    encoding = prepare_input(tokens, tokenizer, max_length=max_length)
    word_ids = encoding.word_ids(batch_index=0)  # List of word IDs
    
    # Move tensors to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # Inference
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).cpu().numpy()[0]
    
    # Decode tokens and labels
    tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids.cpu().numpy()[0])
    labels = [id2label.get(pred, 'O') for pred in predictions]
    
    # Align tokens with original word-level tokens
    aligned_predictions = []
    previous_word_idx = None
    for token, label, word_idx in zip(tokens_decoded, labels, word_ids):
        if word_idx is None:
            continue
        if word_idx != previous_word_idx:
            aligned_predictions.append((tokens[word_idx], label))
            previous_word_idx = word_idx
    return aligned_predictions  # list of (word, predicted_label) pairs, one per input word


def load_token_classifier(pretrained_token_classifier_path, device):
    # Load tokenizer and model from the fine-tuned checkpoint directory
    tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_token_classifier_path)
    token_classifier = DistilBertForTokenClassification.from_pretrained(pretrained_token_classifier_path)
    token_classifier.to(device)
    token_classifier.eval()  # make sure dropout is disabled for inference
    return token_classifier, tokenizer
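

# Minimal usage sketch (illustrative, not part of the original module).
# The checkpoint path below is a placeholder; it must point to a fine-tuned
# DistilBERT token-classification checkpoint whose labels match `label2id`
# above and whose tokenizer includes the two special object tokens.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, tokenizer = load_token_classifier('path/to/token_classifier_checkpoint', device)
    sentence = 'Pick up the <positive_object> and leave the <negative_object> on the table.'
    for word, label in predict(sentence, model, tokenizer, device):
        print(f'{word}\t{label}')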