laserbeam2045 committed on
Commit 2fc7e1b · 1 Parent(s): 0af76ae
Files changed (1)
  1. app.py +40 -23
app.py CHANGED
@@ -1,29 +1,46 @@
-import gradio as gr
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
-import torch
 import os
-
-print(os.getenv("HF_TOKEN"))
-
-# Load the model
-model_name = "unsloth/gemma-3-4b-it"
-processor = AutoProcessor.from_pretrained(model_name)
-model = Gemma3ForConditionalGeneration.from_pretrained(
-    model_name, torch_dtype=torch.bfloat16, device_map="auto"
-)
-
-def generate_text(text, max_length=50):
-    inputs = processor(text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-    outputs = model.generate(**inputs, max_length=max_length)
-    return processor.decode(outputs[0], skip_special_tokens=True)
-
-# Gradio interface
-iface = gr.Interface(
-    fn=generate_text,
-    inputs=["text", "slider"],
-    outputs="text",
-    title="Gemma 3 API"
-)
-
-if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0", server_port=7860)
+import torch
+from fastapi import FastAPI
+from transformers import AutoProcessor, AutoModelForCausalLM
+from pydantic import BaseModel
+import logging
+
+# Logging setup
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+# Load the model
+model_name = "google/gemma-3-4b-it"
+try:
+    logger.info(f"Loading model: {model_name}")
+    processor = AutoProcessor.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        token=os.getenv("HF_TOKEN"),
+        low_cpu_mem_usage=True
+    )
+    logger.info("Model loaded successfully")
+except Exception as e:
+    logger.error(f"Model load error: {e}")
+    raise
+
+class TextInput(BaseModel):
+    text: str
+    max_length: int = 50
+
+@app.post("/generate")
+async def generate_text(input: TextInput):
+    try:
+        logger.info(f"Generating text for input: {input.text}")
+        inputs = processor(input.text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+        outputs = model.generate(**inputs, max_length=input.max_length)
+        result = processor.decode(outputs[0], skip_special_tokens=True)
+        logger.info(f"Generated text: {result}")
+        return {"generated_text": result}
+    except Exception as e:
+        logger.error(f"Generation error: {e}")
+        return {"error": str(e)}
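Note that the new app.py no longer starts a server itself: the old Gradio `iface.launch(...)` block was removed and no uvicorn call replaces it, so the file presumably relies on an external runner such as `uvicorn app:app --host 0.0.0.0 --port 7860`. A minimal client sketch for the new POST /generate endpoint, assuming the app is served on localhost:7860 and that the `requests` package is available (neither the runner command, the URL, nor `requests` is part of this commit):

# Hypothetical client for the /generate endpoint added in this commit.
# The JSON body mirrors the TextInput model: "text" (required), "max_length" (default 50).
import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={"text": "Hello, Gemma!", "max_length": 50},
)
# Expect {"generated_text": "..."} on success, {"error": "..."} on failure,
# matching the two return paths of generate_text above.
print(resp.json())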