# Hugging Face file-viewer header (scrape residue), converted to a comment so
# the file parses: author mjarrett, commit 29969bf ("updated for 8B model"),
# file size 2.38 kB.
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from pydantic import BaseModel
import os
import tarfile
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Debug aid: dump the environment, masking anything that looks like a
# credential (token-bearing keys and the "granite" entry).
logger.info(
    "Environment variables: %s",
    {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()},
)

app = FastAPI()

model_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_path = "/app/granite-8b-finetuned-ascii"

# Unpack the bundled model on first start; skip when already extracted.
if not os.path.exists(model_path):
    logger.info(f"Extracting model tarball: {model_tarball}")
    try:
        # NOTE(review): extractall on an untrusted archive allows path
        # traversal; this tarball is baked into the image, so it is assumed
        # trusted — confirm it is not user-supplied.
        with tarfile.open(model_tarball, "r:gz") as tar:
            tar.extractall(path="/app")
        logger.info("Model tarball extracted successfully")
    except Exception as e:
        logger.exception("Failed to extract model tarball")
        # HTTPException only makes sense inside a request handler; at import
        # time it surfaces as an opaque crash. RuntimeError makes the startup
        # failure explicit and preserves the original cause.
        raise RuntimeError(f"Model tarball extraction failed: {str(e)}") from e
# Load the fine-tuned tokenizer and model from the extracted directory.
# Any failure here is fatal: the service cannot serve without the model.
try:
    logger.info("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = 'right'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,   # half precision to fit the 8B model
        device_map="auto",           # let accelerate place layers on GPU/CPU
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.exception("Failed to load model or tokenizer")
    # RuntimeError, not HTTPException: this runs at import time, outside any
    # request context, so a web-layer exception would never reach a client.
    raise RuntimeError(f"Model initialization failed: {str(e)}") from e
class EditRequest(BaseModel):
    """Request payload for /generate.

    Attributes:
        text: The AsciiDoc sentence to be edited by the model.
    """

    text: str
@app.get("/")
def greet_json():
    """Health-check endpoint: report readiness and which model is loaded."""
    payload = {
        "status": "Model is ready",
        "model": model_path,
    }
    return payload
@app.post("/generate")
async def generate(request: EditRequest):
    """Run the fine-tuned model on the submitted AsciiDoc sentence.

    Args:
        request: Payload whose ``text`` field is the sentence to edit.

    Returns:
        dict with a single ``response`` key holding the decoded output.
        The decoded sequence includes the prompt, since the full output
        tensor is decoded.

    Raises:
        HTTPException: 500 when tokenization or generation fails.
    """
    # NOTE(review): model.generate is a blocking CPU/GPU call inside an async
    # handler, so it stalls the event loop for the duration of generation —
    # consider run_in_executor or a sync (def) endpoint.
    try:
        prompt = f"Edit this AsciiDoc sentence: {request.text}"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # max_new_tokens bounds the *generated* text. The original
        # max_length=200 counted prompt tokens too, so a long prompt could
        # silently truncate (or entirely suppress) the model's output.
        # torch.no_grad() skips autograd bookkeeping — inference only.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Generated response for prompt: {prompt}")
        return {"response": response}
    except Exception as e:
        logger.exception("Generation failed")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")