"""FastAPI service exposing a fine-tuned Granite causal LM for AsciiDoc edits.

On import this module:

1. logs the process environment with secret-looking values masked,
2. unpacks the model tarball into ``/app`` if the model directory is missing,
3. loads the tokenizer and model from that directory.

It then serves ``GET /`` (readiness) and ``POST /generate`` (text generation).
"""
import logging
import os
import tarfile

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Substrings (compared case-insensitively) that mark an environment variable
# as sensitive, so its value is never written to the log.
_SENSITIVE_MARKERS = ("TOKEN", "SECRET", "PASSWORD", "PASSWD", "KEY", "CREDENTIAL")


def _redacted_environ():
    """Return a copy of ``os.environ`` with sensitive values replaced by ``****``."""
    return {
        name: "****"
        if name == "granite" or any(marker in name.upper() for marker in _SENSITIVE_MARKERS)
        else value
        for name, value in os.environ.items()
    }


def _extract_model_archive(archive_path, destination):
    """Unpack the gzip-compressed tar at *archive_path* into *destination*.

    Uses the ``data`` extraction filter when this Python provides it, which
    rejects members that would land outside *destination* (absolute paths,
    ``..`` components, escaping links) and special files.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=destination, filter="data")
        else:
            # Older interpreters have no extraction filters; fall back to the
            # legacy behaviour rather than failing to start.
            tar.extractall(path=destination)


# Debug environment variables
logger.info("Environment variables: %s", _redacted_environ())

app = FastAPI()

model_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_path = "/app/granite-8b-finetuned-ascii"

# Extract tarball if model directory doesn't exist
if not os.path.exists(model_path):
    logger.info("Extracting model tarball: %s", model_tarball)
    try:
        _extract_model_archive(model_tarball, "/app")
    except Exception as e:
        logger.exception("Failed to extract model tarball: %s", e)
        # This runs at import time, outside any request, so an HTTPException
        # would never become an HTTP response; fail startup with a plain error.
        raise RuntimeError(f"Model tarball extraction failed: {e}") from e
    logger.info("Model tarball extracted successfully")

try:
    logger.info("Loading tokenizer and model")
    # NOTE(review): trust_remote_code=True allows Python code shipped inside
    # the model directory to be executed; confirm the tarball is trusted.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = 'right'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.exception("Failed to load model or tokenizer: %s", e)
    # Same reasoning as above: this is a startup failure, not a request error.
    raise RuntimeError(f"Model initialization failed: {e}") from e


class EditRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Sentence to be inserted into the edit prompt.
    text: str


@app.get("/")
def greet_json():
    """Return a readiness payload naming the model directory being served."""
    return {"status": "Model is ready", "model": model_path}


@app.post("/generate")
async def generate(request: EditRequest):
    """Generate an edited version of ``request.text``.

    Builds the edit prompt, tokenizes it onto the model's device, generates up
    to 200 new tokens and decodes the first returned sequence.

    Returns:
        ``{"response": <decoded text>}``.

    Raises:
        HTTPException: status 500 if tokenization, generation or decoding fails.
    """
    # NOTE(review): model.generate is a blocking call inside an async handler,
    # so the event loop is occupied for the duration of each generation.
    try:
        prompt = f"Edit this AsciiDoc sentence: {request.text}"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # max_new_tokens bounds only the generated part; max_length would also
        # count the prompt and leave long inputs with no room for output.
        outputs = model.generate(**inputs, max_new_tokens=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info("Generated response for prompt: %s", prompt)
        return {"response": response}
    except Exception as e:
        logger.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e