# app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
# If the model requires a Hugging Face token, set it in the HF_TOKEN environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# -----------------------------------------------------------------------------
# Device setup (the free Spaces tier is CPU-only)
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------------------------
# Load the tokenizer and model
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # use float32 on CPU
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    # With device_map="auto" the model is already placed on the GPU; only move it on CPU
    model.to(device)
# -----------------------------------------------------------------------------
# FastAPI app
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT API")
class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95
@app.post("/generate")
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="prompt is required.")
    # Tokenize the prompt
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    # Generate a completion
    generation_output = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    return {"generated_text": text}
# -----------------------------------------------------------------------------
# Local launch
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
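# -----------------------------------------------------------------------------
# Example client call (a sketch, not part of the app): assumes the server is
# running locally on port 8000 and that the `requests` package is installed.
# Field names follow the GenerationRequest model defined above.
# -----------------------------------------------------------------------------
# import requests
#
# resp = requests.post(
#     "http://localhost:8000/generate",
#     json={"prompt": "Write a haiku about autumn.", "max_new_tokens": 64},
#     timeout=120,
# )
# resp.raise_for_status()
# print(resp.json()["generated_text"])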