Spaces:
Runtime error
Runtime error
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments | |
app = FastAPI() | |
# Initialize the tokenizer and model | |
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") | |
model = BertForMaskedLM.from_pretrained("bert-base-uncased") | |
# Prepare the training data | |
with open("cyberpunk_lore.txt", "r") as f: | |
train_data = f.read() | |
train_data = train_data.split("\n") | |
train_data = [tokenizer.encode(text, return_tensors="pt") for text in train_data] | |
# Define the training arguments | |
training_args = TrainingArguments( | |
output_dir="./results", | |
per_device_train_batch_size=16, | |
save_steps=10_000, | |
save_total_limit=2, | |
) | |
# Create the trainer | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=train_data, | |
eval_dataset=train_data, | |
) | |
# Start the training | |
trainer.train() | |
# Save the fine-tuned model | |
trainer.save_model('./results') | |
# Load the fine-tuned model | |
model = trainer.get_model() | |
# Create the inference endpoint | |
def infer(input: str): | |
input_ids = tokenizer.encode(input, return_tensors="pt") | |
output = model(input_ids)[0] | |
return {"output": output} | |
def index() -> FileResponse: | |
return FileResponse(path="/app/static/index.html", media_type="text/html") | |
def index() -> FileResponse: | |
return FileResponse(path="/app/static/index.html", media_type="text/html") |