File size: 1,988 Bytes
ae51d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from fastapi import FastAPI
from pydantic import BaseModel
from hypothesis import BaseModelHypothesis
from randomforest import RandomForestDependencies
import torch.nn as nn
import torch


class AlbertCustomClassificationHead(nn.Module):
    """Single-logit classification head over an encoder plus extra features.

    The encoder's ``pooler_output`` is concatenated with a tensor of
    additional features along dim 1, passed through dropout, and projected
    to one logit per row by a linear layer.

    Args:
        albert_model: Encoder called as
            ``albert_model(input_ids=..., attention_mask=...)``; the returned
            object must expose a ``pooler_output`` tensor.
        dropout_rate: Dropout probability applied to the combined features.
        hidden_size: Width of ``pooler_output`` along dim 1. Defaults to
            1024, the value previously hard-coded.
        num_additional_features: Width of ``additional_features`` along
            dim 1. Defaults to 25, the value previously hard-coded.
    """

    def __init__(
        self,
        albert_model,
        dropout_rate: float = 0.1,
        *,
        hidden_size: int = 1024,
        num_additional_features: int = 25,
    ):
        super().__init__()
        # Attribute names are kept stable so existing state_dict keys
        # (albert_model.*, classifier.*) continue to load.
        self.albert_model = albert_model
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(hidden_size + num_additional_features, 1)

    def forward(self, input_ids, attention_mask, additional_features, labels=None):
        """Compute logits and, when labels are given, the BCE-with-logits loss.

        Args:
            input_ids: Forwarded unchanged to the encoder.
            attention_mask: Forwarded unchanged to the encoder.
            additional_features: Tensor concatenated to ``pooler_output``
                along dim 1 (assumed 2-D ``(batch, features)`` -- confirm
                against the caller).
            labels: Optional targets. They are unsqueezed at dim 1 and cast
                to float before the loss, so a 1-D ``(batch,)`` tensor is
                presumably expected -- confirm.

        Returns:
            ``logits`` when ``labels`` is ``None``; otherwise the tuple
            ``(logits, loss)``.
        """
        albert_output = self.albert_model(
            input_ids=input_ids, attention_mask=attention_mask).pooler_output

        combined_features = torch.cat(
            [albert_output, additional_features], dim=1)
        logits = self.classifier(self.dropout(combined_features))

        if labels is None:
            return logits

        # Add a trailing dim so the targets line up with the (N, 1) logits.
        targets = labels.unsqueeze(1).float()
        loss = nn.BCEWithLogitsLoss()(logits, targets)
        return logits, loss


# FastAPI application object; the POST /predict handler in this module is
# registered on it through the @app.post decorator.
app = FastAPI()


class PredictRequest(BaseModel):
    # Request body schema for POST /predict.
    # NOTE(review): field meanings below are inferred from their names and
    # from how the /predict handler uses them -- confirm with the client.

    # Not read by the /predict handler beyond being unpacked.
    question: str
    # Text passed to the BaseModelHypothesis feature calculators.
    answer: str
    # Value currently returned verbatim by the /predict handler.
    backspace_count: int
    # Presumably how long the user spent typing; unit not shown -- confirm.
    typing_duration: int
    # Mapping of string keys to integer counts (presumably per-key presses).
    letter_click_counts: dict[str, int]


@app.post("/predict")
async def predict(request: PredictRequest):
    """Handle ``POST /predict``.

    Computes two feature sets from ``request.answer`` using
    ``BaseModelHypothesis`` and returns ``request.backspace_count``.

    Args:
        request: Validated request body.

    Returns:
        The ``backspace_count`` value from the request.
    """
    # Read validated fields directly from the model rather than dumping it to
    # a dict and looking the same keys up again.
    answer = request.answer

    hypothesis = BaseModelHypothesis()
    # NOTE(review): both feature results are computed but not yet used in the
    # response. The calls are kept because they may raise or have side
    # effects -- confirm whether they should feed a model prediction.
    features_normalized_text_length = (
        hypothesis.calculate_normalized_text_length_features(answer)
    )
    features_not_normalized = hypothesis.calculate_not_normalized_features(
        answer
    )

    return request.backspace_count