Spaces:
Running
Running
File size: 1,988 Bytes
ae51d62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from fastapi import FastAPI
from pydantic import BaseModel
from hypothesis import BaseModelHypothesis
from randomforest import RandomForestDependencies
import torch.nn as nn
import torch
class AlbertCustomClassificationHead(nn.Module):
    """Binary classification head over an ALBERT encoder.

    Concatenates the encoder's pooled output (assumed 1024-dim — albert-xlarge
    sizing; TODO confirm against the model actually passed in) with 25
    additional hand-crafted features, then projects to a single logit.
    """

    def __init__(self, albert_model, dropout_rate=0.1):
        """
        Args:
            albert_model: callable(input_ids=..., attention_mask=...) returning
                an object with a ``pooler_output`` tensor of shape (batch, 1024).
            dropout_rate: dropout probability applied to the combined features.
        """
        super().__init__()  # zero-arg super(): Python 3 idiom
        self.albert_model = albert_model
        self.dropout = nn.Dropout(dropout_rate)
        # 1024 pooled ALBERT dims + 25 extra features -> single logit
        self.classifier = nn.Linear(1024 + 25, 1)

    def forward(self, input_ids, attention_mask, additional_features, labels=None):
        """Compute logits, and the BCE loss when labels are provided.

        Args:
            input_ids, attention_mask: forwarded verbatim to the encoder.
            additional_features: (batch, 25) tensor concatenated to the pooled output.
            labels: optional (batch,) binary targets.

        Returns:
            logits of shape (batch, 1), or ``(logits, loss)`` when ``labels``
            is given — matching the Hugging Face with-labels convention.
        """
        albert_output = self.albert_model(
            input_ids=input_ids, attention_mask=attention_mask).pooler_output
        combined_features = torch.cat([albert_output, additional_features], dim=1)
        logits = self.classifier(self.dropout(combined_features))
        if labels is None:
            return logits
        # BCEWithLogitsLoss needs float targets shaped like the logits: (batch, 1)
        loss = nn.BCEWithLogitsLoss()(logits, labels.unsqueeze(1).float())
        return logits, loss
# ASGI application instance served by uvicorn/gunicorn.
app = FastAPI()

# Request schema for POST /predict. Field names and types are the wire
# contract — changing them breaks existing clients.
class PredictRequest(BaseModel):
    question: str
    answer: str
    # Count of backspace presses while typing the answer.
    backspace_count: int
    # Total typing time; unit (seconds vs milliseconds) not shown here — TODO confirm.
    typing_duration: int
    # Per-letter click frequencies; presumably keys are single characters — verify against caller.
    letter_click_counts: dict[str, int]
@app.post("/predict")
async def predict(request: PredictRequest):
request_dict = request.model_dump()
question = request_dict.get("question")
answer = request_dict.get("answer")
backspace_count = request_dict.get("backspace_count")
typing_duration = request_dict.get("typing_duration")
letter_click_counts = request_dict.get("letter_click_counts")
hypothesis = BaseModelHypothesis()
features_normalized_text_length = hypothesis.calculate_normalized_text_length_features(
answer)
features_not_normalized = hypothesis.calculate_not_normalized_features(
answer)
return request_dict.get("backspace_count")
|