# app.py import os from fastapi import FastAPI, HTTPException from pydantic import BaseModel from huggingface_hub import hf_hub_download from pyllamacpp.model import Model # ----------------------------------------------------------------------------- # Hugging Face Hub の設定 # ----------------------------------------------------------------------------- HF_TOKEN = os.environ.get("HF_TOKEN") # 必要に応じて Secrets にセット REPO_ID = "google/gemma-3-12b-it-qat-q4_0-gguf" # 実際にリポジトリに置かれている GGUF ファイル名を確認してください。 # 例: "gemma-3-12b-it-qat-q4_0-gguf.gguf" GGUF_FILENAME = "gemma-3-12b-it-q4_0.gguf" # キャッシュ先のパス(リポジトリ直下に置く場合) MODEL_PATH = os.path.join(os.getcwd(), GGUF_FILENAME) # ----------------------------------------------------------------------------- # 起動時に一度だけダウンロード # ----------------------------------------------------------------------------- if not os.path.exists(MODEL_PATH): print(f"Downloading {GGUF_FILENAME} from {REPO_ID} …") hf_hub_download( repo_id=REPO_ID, filename=GGUF_FILENAME, token=HF_TOKEN, repo_type="model", # 明示的にモデルリポジトリを指定 local_dir=os.getcwd(), # カレントディレクトリに保存 local_dir_use_symlinks=False ) # ----------------------------------------------------------------------------- # llama.cpp (pyllamacpp) で 4bit GGUF モデルをロード # ----------------------------------------------------------------------------- llm = Model( model_path=MODEL_PATH, n_ctx=512, # 必要に応じて調整 n_threads=4, # 実マシンのコア数に合わせて ) # ----------------------------------------------------------------------------- # FastAPI 定義 # ----------------------------------------------------------------------------- app = FastAPI(title="Gemma3-12B-IT Q4_0 GGUF API") class GenerationRequest(BaseModel): prompt: str max_new_tokens: int = 128 temperature: float = 0.8 top_p: float = 0.95 @app.post("/generate") async def generate(req: GenerationRequest): if not req.prompt: raise HTTPException(status_code=400, detail="`prompt` は必須です。") # llama.cpp の generate を呼び出し text = llm.generate( req.prompt, top_p=req.top_p, temp=req.temperature, n_predict=req.max_new_tokens, repeat_last_n=64, repeat_penalty=1.1 ) return {"generated_text": text} # ----------------------------------------------------------------------------- # ローカル起動用 # ----------------------------------------------------------------------------- if __name__ == "__main__": import uvicorn port = int(os.environ.get("PORT", 8000)) uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")