Spaces:
Running
Running
# app.py | |
import os | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from huggingface_hub import hf_hub_download | |
from pyllamacpp.model import Model | |
# ----------------------------------------------------------------------------- | |
# Hugging Face Hub の設定 | |
# ----------------------------------------------------------------------------- | |
HF_TOKEN = os.environ.get("HF_TOKEN") # 必要に応じて Secrets にセット | |
REPO_ID = "google/gemma-3-12b-it-qat-q4_0-gguf" | |
# 実際にリポジトリに置かれている GGUF ファイル名を確認してください。 | |
# 例: "gemma-3-12b-it-qat-q4_0-gguf.gguf" | |
GGUF_FILENAME = "gemma-3-12b-it-qat-q4_0-gguf.gguf" | |
# キャッシュ先のパス(リポジトリ直下に置く場合) | |
MODEL_PATH = os.path.join(os.getcwd(), GGUF_FILENAME) | |
# ----------------------------------------------------------------------------- | |
# 起動時に一度だけダウンロード | |
# ----------------------------------------------------------------------------- | |
if not os.path.exists(MODEL_PATH): | |
print(f"Downloading {GGUF_FILENAME} from {REPO_ID} …") | |
hf_hub_download( | |
repo_id=REPO_ID, | |
filename=GGUF_FILENAME, | |
token=HF_TOKEN, | |
repo_type="model", # 明示的にモデルリポジトリを指定 | |
local_dir=os.getcwd(), # カレントディレクトリに保存 | |
local_dir_use_symlinks=False | |
) | |
# ----------------------------------------------------------------------------- | |
# llama.cpp (pyllamacpp) で 4bit GGUF モデルをロード | |
# ----------------------------------------------------------------------------- | |
llm = Model( | |
model_path=MODEL_PATH, | |
n_ctx=512, # 必要に応じて調整 | |
n_threads=4, # 実マシンのコア数に合わせて | |
) | |
# ----------------------------------------------------------------------------- | |
# FastAPI 定義 | |
# ----------------------------------------------------------------------------- | |
app = FastAPI(title="Gemma3-12B-IT Q4_0 GGUF API") | |
class GenerationRequest(BaseModel): | |
prompt: str | |
max_new_tokens: int = 128 | |
temperature: float = 0.8 | |
top_p: float = 0.95 | |
async def generate(req: GenerationRequest): | |
if not req.prompt: | |
raise HTTPException(status_code=400, detail="`prompt` は必須です。") | |
# llama.cpp の generate を呼び出し | |
text = llm.generate( | |
req.prompt, | |
top_p=req.top_p, | |
temp=req.temperature, | |
n_predict=req.max_new_tokens, | |
repeat_last_n=64, | |
repeat_penalty=1.1 | |
) | |
return {"generated_text": text} | |
# ----------------------------------------------------------------------------- | |
# ローカル起動用 | |
# ----------------------------------------------------------------------------- | |
if __name__ == "__main__": | |
import uvicorn | |
port = int(os.environ.get("PORT", 8000)) | |
uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info") | |