# app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DEVICE = torch.device("cpu")  # the free tier only provides CPU
# -----------------------------------------------------------------------------
# Load the tokenizer
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True
)
# -----------------------------------------------------------------------------
# Load the model in low-memory mode
# -----------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
).to(DEVICE)
# -----------------------------------------------------------------------------
# Apply dynamic quantization
# -----------------------------------------------------------------------------
# - quantizes the {torch.nn.Linear} layers to INT8
# - dtype=torch.qint8 quantizes the weights only
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)
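# -----------------------------------------------------------------------------
# Optional size check (illustrative sketch, not part of the original app;
# the CHECK_MODEL_SIZE environment variable below is an assumed name)
# -----------------------------------------------------------------------------
def _serialized_size_mb(m: torch.nn.Module) -> float:
    # Dynamic quantization stores the INT8 weights in packed params rather than
    # in .parameters(), so serializing the state_dict is the simplest way to
    # observe the footprint reduction.
    import tempfile
    fd, path = tempfile.mkstemp(suffix=".pt")
    os.close(fd)
    torch.save(m.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
    os.remove(path)
    return size_mb


if os.environ.get("CHECK_MODEL_SIZE"):
    print(f"Quantized model size: {_serialized_size_mb(model):.1f} MB")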
# -----------------------------------------------------------------------------
# FastAPI server definition
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT with Dynamic Quantization")


class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95


@app.post("/generate")  # register the endpoint (path inferred from the handler name)
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="`prompt` is required.")
    # Tokenize the prompt and run inference
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(DEVICE)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": text}
# -----------------------------------------------------------------------------
# Local launch
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
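# -----------------------------------------------------------------------------
# Example client call (illustrative sketch; assumes the POST /generate route
# registered above and a server already listening on localhost:8000)
# -----------------------------------------------------------------------------
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Hello, Gemma!", "max_new_tokens": 64},
#   )
#   print(resp.json()["generated_text"])
#
# Note that the returned text includes the prompt, because the full output
# sequence (input tokens plus newly generated tokens) is passed to
# tokenizer.decode().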