MATRIX / app.py
laserbeam2045
fix
7f80d8c
raw
history blame
2.75 kB
# app.py
import asyncio
import os
import threading

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Settings
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
# Hugging Face access token, needed when the model repo requires authentication.
# Read from the HF_TOKEN environment variable; HUGGINGFACE_TOKEN is accepted as
# a fallback. None when neither is set (or both are empty).
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
# -----------------------------------------------------------------------------
# Device selection (on the free Spaces tier only a CPU is available)
# -----------------------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -----------------------------------------------------------------------------
# Load the tokenizer and model
# -----------------------------------------------------------------------------
# `token=` replaces the deprecated `use_auth_token=` keyword. `trust_remote_code`
# is not passed: the model load below never enabled it, so the tokenizer does
# not opt in to executing code downloaded from the Hub either.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)

# With CUDA, let accelerate place the weights via device_map="auto"; otherwise
# load without a device map (weights land on the CPU) and place them on
# `device` explicitly below.
_device_map = "auto" if torch.cuda.is_available() else None
# NOTE(review): google/gemma-3-4b-it is an image+text checkpoint; confirm the
# installed transformers version resolves it through AutoModelForCausalLM.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # float32 for CPU inference (kept as-is on GPU)
    device_map=_device_map,
)
if _device_map is None:
    # Only move the model when no device map was used: calling `.to()` on a
    # model dispatched by accelerate is unsupported (it warns, and raises if
    # any weights were offloaded to CPU/disk).
    model.to(device)
# -----------------------------------------------------------------------------
# FastAPI definition
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT API")


class GenerationRequest(BaseModel):
    """JSON body accepted by ``POST /generate``.

    The numeric bounds match what sampling in ``model.generate`` accepts
    (positive ``max_new_tokens`` and ``temperature``, ``top_p`` within
    [0, 1]). Out-of-range values are therefore rejected by FastAPI request
    validation (HTTP 422) instead of raising inside generation (HTTP 500).
    """

    prompt: str  # prompt text; tokenized as-is, no chat template is applied
    max_new_tokens: int = Field(128, gt=0)  # cap on newly sampled tokens
    temperature: float = Field(0.8, gt=0.0)  # sampling temperature
    top_p: float = Field(0.95, ge=0.0, le=1.0)  # nucleus-sampling mass
# Generation runs in worker threads (see `generate`). This lock keeps at most
# one tokenize/generate/decode sequence in flight, which bounds memory use and
# guards the shared tokenizer. NOTE: HF fast tokenizers are reported to raise
# "Already borrowed" when called concurrently with padding/truncation enabled.
_generation_lock = threading.Lock()


def _run_generation(req: GenerationRequest) -> str:
    """Tokenize ``req.prompt``, sample a continuation and decode it (blocking)."""
    with _generation_lock:
        inputs = tokenizer(
            req.prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(model.device)  # wherever the (possibly dispatched) model lives
        generation_output = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # The output sequence contains the prompt tokens followed by the new
        # ones, so the decoded text includes the prompt.
        return tokenizer.decode(generation_output[0], skip_special_tokens=True)


@app.post("/generate")
async def generate(req: GenerationRequest):
    """Sample a continuation of ``req.prompt`` and return it.

    Returns:
        ``{"generated_text": text}``, where ``text`` is the decoded output
        sequence (prompt plus continuation) with special tokens removed.

    Raises:
        HTTPException: 400 when ``req.prompt`` is empty.
    """
    if not req.prompt:
        raise HTTPException(status_code=400, detail="prompt は必須です。")
    # model.generate is a long blocking call; running it directly in this
    # coroutine would stall the event loop (and every other request) until it
    # finished, so do the work in a worker thread instead.
    text = await asyncio.to_thread(_run_generation, req)
    return {"generated_text": text}
# -----------------------------------------------------------------------------
# Local entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    # Pass the already-built `app` object, not the import string "app:app".
    # When this file runs as a script it is the `__main__` module, so uvicorn
    # resolving "app:app" would import app.py a second time and load the model
    # weights again.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")