laserbeam2045 commited on
Commit
9d3ba14
·
1 Parent(s): 215bcb0
Files changed (2) hide show
  1. app.py +40 -15
  2. requirements.txt +1 -1
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from huggingface_hub import hf_hub_download
6
- from pyllamacpp.model import Model
7
 
8
  # -----------------------------------------------------------------------------
9
  # Hugging Face Hub の設定
@@ -32,11 +32,24 @@ if not os.path.exists(MODEL_PATH):
32
  )
33
 
34
  # -----------------------------------------------------------------------------
35
- # llama.cpp (pyllamacpp) で 4bit GGUF モデルをロード
36
  # -----------------------------------------------------------------------------
37
- llm = Model(
38
- model_path=MODEL_PATH,
39
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # -----------------------------------------------------------------------------
42
  # FastAPI 定義
@@ -48,21 +61,32 @@ class GenerationRequest(BaseModel):
48
  max_new_tokens: int = 128
49
  temperature: float = 0.8
50
  top_p: float = 0.95
 
 
 
51
 
52
  @app.post("/generate")
53
  async def generate(req: GenerationRequest):
54
  if not req.prompt:
55
  raise HTTPException(status_code=400, detail="`prompt` は必須です。")
56
- # llama.cpp の generate を呼び出し
57
- text = llm.generate(
58
- req.prompt,
59
- top_p=req.top_p,
60
- temp=req.temperature,
61
- n_predict=req.max_new_tokens,
62
- repeat_last_n=64,
63
- repeat_penalty=1.1
64
- )
65
- return {"generated_text": text}
 
 
 
 
 
 
 
 
66
 
67
  # -----------------------------------------------------------------------------
68
  # ローカル起動用
@@ -70,4 +94,5 @@ async def generate(req: GenerationRequest):
70
  if __name__ == "__main__":
71
  import uvicorn
72
  port = int(os.environ.get("PORT", 8000))
 
73
  uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
 
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from huggingface_hub import hf_hub_download
6
+ from llama_cpp import Llama # llama-cpp-python をインポート
7
 
8
  # -----------------------------------------------------------------------------
9
  # Hugging Face Hub の設定
 
32
  )
33
 
34
  # -----------------------------------------------------------------------------
35
+ # llama-cpp-python で 4bit GGUF モデルをロード
36
  # -----------------------------------------------------------------------------
37
+ print(f"Loading model from {MODEL_PATH}...")
38
+ try:
39
+ llm = Llama(
40
+ model_path=MODEL_PATH,
41
+ n_ctx=2048, # コンテキストサイズ (モデルに合わせて調整してください)
42
+ # n_gpu_layers=-1, # GPU を使う場合 (Hugging Face Spaces 無料枠では通常 0)
43
+ n_gpu_layers=0, # CPU のみ使用
44
+ verbose=True # 詳細ログを出力
45
+ )
46
+ print("Model loaded successfully.")
47
+ except Exception as e:
48
+ print(f"Error loading model: {e}")
49
+ # エラーが発生した場合、アプリケーションを終了させるか、エラーハンドリングを行う
50
+ # ここでは簡単なエラーメッセージを出力して終了する例
51
+ raise RuntimeError(f"Failed to load the LLM model: {e}")
52
+
53
 
54
  # -----------------------------------------------------------------------------
55
  # FastAPI 定義
 
61
  max_new_tokens: int = 128
62
  temperature: float = 0.8
63
  top_p: float = 0.95
64
+ # llama-cpp-python で利用可能な他のパラメータも追加可能
65
+ # stop: list[str] | None = None
66
+ # repeat_penalty: float = 1.1
67
 
68
  @app.post("/generate")
69
  async def generate(req: GenerationRequest):
70
  if not req.prompt:
71
  raise HTTPException(status_code=400, detail="`prompt` は必須です。")
72
+
73
+ try:
74
+ # llama-cpp-python の __call__ メソッドで生成
75
+ output = llm(
76
+ req.prompt,
77
+ max_tokens=req.max_new_tokens,
78
+ temperature=req.temperature,
79
+ top_p=req.top_p,
80
+ # stop=req.stop, # 必要なら追加
81
+ # repeat_penalty=req.repeat_penalty, # 必要なら追加
82
+ )
83
+ # 生成されたテキストを取得
84
+ generated_text = output["choices"][0]["text"]
85
+ return {"generated_text": generated_text}
86
+ except Exception as e:
87
+ print(f"Error during generation: {e}")
88
+ raise HTTPException(status_code=500, detail=f"生成中にエラーが発生しました: {e}")
89
+
90
 
91
  # -----------------------------------------------------------------------------
92
  # ローカル起動用
 
94
  if __name__ == "__main__":
95
  import uvicorn
96
  port = int(os.environ.get("PORT", 8000))
97
+ # アプリケーションのロードに失敗した場合に備えて try-except を追加することも検討
98
  uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  fastapi
2
  uvicorn[standard]
3
- pyllamacpp
 
1
  fastapi
2
  uvicorn[standard]
3
+ llama-cpp-python