# app.py
import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"

# Set the HF_TOKEN environment variable if the model requires a Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# -----------------------------------------------------------------------------
# Device setup (the free Spaces tier is CPU-only)
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------------------------
# Load the tokenizer and model
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # use float32 on CPU
    device_map="auto" if torch.cuda.is_available() else None,
)

# Only move the model manually when device_map has not already placed it
if not torch.cuda.is_available():
    model.to(device)
model.eval()
# -----------------------------------------------------------------------------
# FastAPI app
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT API")


class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95
@app.post("/generate")
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="prompt is required.")

    # Tokenize the prompt
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Generate
    with torch.no_grad():
        generation_output = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    return {"generated_text": text}
# -----------------------------------------------------------------------------
# Local entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")