# fastapi_t5 / main.py
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
import torch

app = FastAPI()
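
# Assumption: the front-end assets live in /app/static (the same directory used by
# the FileResponse below); mounting it lets FastAPI serve any CSS/JS next to index.html.
app.mount("/static", StaticFiles(directory="/app/static"), name="static")
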
# Initialize the tokenizer and model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Prepare the training data: one example per non-empty line of the corpus
with open("cyberpunk_lore.txt", "r") as f:
    train_texts = [line for line in f.read().split("\n") if line.strip()]

# Tokenize each line; padding and MLM masking are applied per batch by the data collator
train_dataset = [tokenizer(text, truncation=True, max_length=128) for text in train_texts]

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
)

# Create the trainer; DataCollatorForLanguageModeling pads each batch and randomly
# masks tokens so BertForMaskedLM is trained with a masked-language-modeling objective
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True),
    train_dataset=train_dataset,
    eval_dataset=train_dataset,  # no held-out split; evaluation reuses the training data
)
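
# NOTE: trainer.train() runs at import time, so the server only starts accepting
# requests once fine-tuning has finished.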
# Start the training
trainer.train()
# Save the fine-tuned model
trainer.save_model('./results')

# Load the fine-tuned model back from the trainer for inference
model = trainer.model

# Create the inference endpoint
@app.post("/infer")
def infer(input: str):
    # Run the fine-tuned model and decode the most likely token at every position
    # (for [MASK] tokens this returns the model's fill-in)
    input_ids = tokenizer.encode(input, return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids).logits
    predicted_ids = logits.argmax(dim=-1)
    return {"output": tokenizer.decode(predicted_ids[0], skip_special_tokens=True)}
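
# Example request (hypothetical port; Hugging Face Spaces Docker apps commonly use 7860):
#   curl -X POST "http://localhost:7860/infer?input=Night%20City%20never%20sleeps"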

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")