File size: 4,526 Bytes
7f80d8c
dea3a07
7f80d8c
115a37b
6dd176e
9d3ba14
7f80d8c
 
6dd176e
7f80d8c
6dd176e
d31307a
8828cda
d31307a
 
7f80d8c
6dd176e
 
e604a26
7f80d8c
6dd176e
7f80d8c
6dd176e
 
 
 
 
 
 
 
 
 
7f80d8c
9736832
9d3ba14
9736832
9d3ba14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f80d8c
 
6dd176e
7f80d8c
8828cda
7f80d8c
9fa7b35
 
 
 
7f80d8c
 
 
 
 
9d3ba14
 
 
115a37b
 
7f80d8c
 
9736832
9d3ba14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f80d8c
c6f721c
 
 
 
 
59b6b5d
 
 
c6f721c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# app.py
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama # llama-cpp-python をインポート

# -----------------------------------------------------------------------------
# Hugging Face Hub の設定
# -----------------------------------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # 必要に応じて Secrets にセット
REPO_ID  = "google/gemma-3-4b-it-qat-q4_0-gguf"
# 実際にリポジトリに置かれている GGUF ファイル名を確認してください。
# 例: "gemma-3-4b-it-qat-q4_0-gguf.gguf"
GGUF_FILENAME = "gemma-3-4b-it-q4_0.gguf"

# キャッシュ先のパス(リポジトリ直下に置く場合)
MODEL_PATH = os.path.join(os.getcwd(), GGUF_FILENAME)

# -----------------------------------------------------------------------------
# 起動時に一度だけダウンロード
# -----------------------------------------------------------------------------
if not os.path.exists(MODEL_PATH):
    print(f"Downloading {GGUF_FILENAME} from {REPO_ID} …")
    hf_hub_download(
        repo_id=REPO_ID,
        filename=GGUF_FILENAME,
        token=HF_TOKEN,
        repo_type="model",        # 明示的にモデルリポジトリを指定
        local_dir=os.getcwd(),    # カレントディレクトリに保存
        local_dir_use_symlinks=False
    )

# -----------------------------------------------------------------------------
# llama-cpp-python で 4bit GGUF モデルをロード
# -----------------------------------------------------------------------------
print(f"Loading model from {MODEL_PATH}...")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,      # コンテキストサイズ (モデルに合わせて調整してください)
        # n_gpu_layers=-1, # GPU を使う場合 (Hugging Face Spaces 無料枠では通常 0)
        n_gpu_layers=0,   # CPU のみ使用
        verbose=True     # 詳細ログを出力
    )
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # エラーが発生した場合、アプリケーションを終了させるか、エラーハンドリングを行う
    # ここでは簡単なエラーメッセージを出力して終了する例
    raise RuntimeError(f"Failed to load the LLM model: {e}")


# -----------------------------------------------------------------------------
# FastAPI 定義
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT Q4_0 GGUF API")

@app.get("/")
async def read_root():
    return {"message": "Gemma3 API is running"}

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95
    # llama-cpp-python で利用可能な他のパラメータも追加可能
    # stop: list[str] | None = None
    # repeat_penalty: float = 1.1

@app.post("/generate")
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="`prompt` は必須です。")

    try:
        # llama-cpp-python の __call__ メソッドで生成
        output = llm(
            req.prompt,
            max_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            # stop=req.stop, # 必要なら追加
            # repeat_penalty=req.repeat_penalty, # 必要なら追加
        )
        # 生成されたテキストを取得
        generated_text = output["choices"][0]["text"]
        return {"generated_text": generated_text}
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"生成中にエラーが発生しました: {e}")


# -----------------------------------------------------------------------------
# Uvicorn サーバーの起動 (Hugging Face Spaces 用)
# -----------------------------------------------------------------------------
# if __name__ == "__main__": ガードは付けずに直接実行
import uvicorn
# Hugging Face Spaces で標準的に使用されるポート 7860 を明示的に指定
port = 7860
# port = int(os.environ.get("PORT", 7860)) # 環境変数があればそれを使う方が良い場合もある
# host="0.0.0.0" でコンテナ外からのアクセスを許可
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")