File size: 956 Bytes
fbdcb75
78047da
fbdcb75
 
78047da
151d72b
1d29ee7
 
 
 
 
fbdcb75
1d29ee7
 
fbdcb75
1d29ee7
 
 
 
 
 
fbdcb75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')

def generate_answers(query):
    # Garantindo que a query é uma string
    if not isinstance(query, str):
        raise ValueError("A entrada para a função generate_answers deve ser uma string.")
    
    # Tokenização
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    
    # Realizando a predição
    outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()  # Converter tensor para um inteiro
    
    # Labels devem corresponder ao número de classes do modelo
    labels = ['ds', 'real', 'Group']
    predicted_label = labels[prediction]  # Usando o índice para acessar a label
    
    return predicted_label