# Captured from a Hugging Face Space page.
# Space status at capture time: Runtime error.
# File size: 1,523 bytes. Revisions touching this file: 0b0f289, 3ecf051, 62a4b51.
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from transformers import (
    BertForMaskedLM,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

app = FastAPI()

# Initialize the pretrained tokenizer and masked-language-model head.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Prepare the training data: one example per non-blank line of the lore file.
# Blank lines are skipped; they would otherwise become empty [CLS][SEP] examples.
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    train_texts = [line.strip() for line in f if line.strip()]
if not train_texts:
    raise ValueError("cyberpunk_lore.txt contains no non-blank lines to train on")

# Each example is a dict of plain token-id lists (input_ids, attention_mask, ...)
# so the collator can pad and batch them. Truncation keeps every line within
# BERT's maximum sequence length.
train_data = [tokenizer(text, truncation=True) for text in train_texts]

# BertForMaskedLM only computes a loss when it receives `labels`. This collator
# pads each batch, randomly masks 15% of the tokens and builds matching labels.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
)

# Create the trainer.
# NOTE(review): eval_dataset is the training set itself; it is only used if an
# evaluation strategy is configured in training_args.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=train_data,
)

# Start the training
trainer.train()

# Save the fine-tuned weights, plus the tokenizer so ./results is loadable on its own.
trainer.save_model("./results")
tokenizer.save_pretrained("./results")

# Use the fine-tuned model for serving. Trainer has no `get_model()` method;
# the trained model is exposed as `trainer.model`.
model = trainer.model
model.eval()  # disable dropout so inference is deterministic
# Create the inference endpoint
@app.post("/infer")
def infer(input: str):
    """Run the masked-language model on ``input`` and return its logits.

    FastAPI exposes ``input`` as a required query parameter of that name.
    The name is kept, although it shadows the builtin, because renaming it
    would change the public HTTP API.

    Returns:
        ``{"output": logits}``, where ``logits`` is the model's MLM prediction
        scores converted to a nested list of floats (leading batch dimension
        of 1, then one row per token, one score per vocabulary entry).
    """
    # The model may still be in training mode (dropout active) after
    # fine-tuning; switch to eval mode for deterministic inference.
    model.eval()
    # Truncate so inputs longer than BERT's maximum length don't crash, and
    # place the ids on whatever device the (possibly GPU-trained) model is on.
    input_ids = tokenizer.encode(input, return_tensors="pt", truncation=True)
    input_ids = input_ids.to(model.device)
    with torch.no_grad():  # inference only; don't build an autograd graph
        logits = model(input_ids)[0]
    # FastAPI cannot JSON-encode a raw torch.Tensor, so convert it to plain
    # Python floats before returning.
    return {"output": logits.tolist()}
@app.get("/")
def index() -> FileResponse:
    """Serve the static front-end page at the site root.

    Defined exactly once: a second identical ``@app.get("/")`` registration
    would add an unreachable duplicate route and rebind ``index``.
    """
    return FileResponse(path="/app/static/index.html", media_type="text/html")