laserbeam2045 committed
Commit d1fd8de · 1 Parent(s): 861971b
Files changed (2)
  1. app.py +29 -25
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,8 +1,7 @@
 import os
 import torch
-from fastapi import FastAPI
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-from pydantic import BaseModel
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 
 logging.basicConfig(level=logging.INFO)
@@ -11,43 +10,48 @@ logger = logging.getLogger(__name__)
-app = FastAPI()
 
 model_name = "google/gemma-2-2b-it"
+tokenizer = None
+model = None
+
 try:
     logger.info(f"Loading model: {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))
     use_gpu = torch.cuda.is_available()
     logger.info(f"GPU available: {use_gpu}")
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_use_double_quant=True
-    ) if use_gpu else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
+        torch_dtype=torch.float16,  # reduce memory footprint
+        device_map="cpu",  # no GPU available
         token=os.getenv("HF_TOKEN"),
-        low_cpu_mem_usage=True,
-        quantization_config=quantization_config
+        low_cpu_mem_usage=True
     )
     logger.info("Model loaded successfully")
 except Exception as e:
     logger.error(f"Model load error: {e}")
     raise
 
-class TextInput(BaseModel):
-    text: str
-    max_length: int = 50
-
-@app.post("/generate")
-async def generate_text(input: TextInput):
+def generate_text(text, max_length=50):
     try:
-        logger.info(f"Generating text for input: {input.text}")
-        inputs = tokenizer(input.text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-        outputs = model.generate(**inputs, max_length=input.max_length)
+        logger.info(f"Generating text for input: {text}")
+        inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to("cpu")
+        outputs = model.generate(**inputs, max_length=max_length)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logger.info(f"Generated text: {result}")
-        return {"generated_text": result}
+        return result
     except Exception as e:
         logger.error(f"Generation error: {e}")
-        return {"error": str(e)}
+        return f"Error: {str(e)}"
+
+iface = gr.Interface(
+    fn=generate_text,
+    inputs=[gr.Textbox(label="Input Text"), gr.Slider(10, 100, value=50, label="Max Length")],
+    outputs=gr.Textbox(label="Generated Text"),
+    title="Gemma 2 API"
+)
+
+if __name__ == "__main__":
+    try:
+        logger.info("Launching Gradio interface")
+        iface.launch(server_name="0.0.0.0", server_port=8080)
+    except Exception as e:
+        logger.error(f"Gradio launch error: {e}")
+        raise
requirements.txt CHANGED
@@ -6,3 +6,4 @@ bitsandbytes==0.42.0
 accelerate==0.26.1
 fastapi==0.115.0
 uvicorn==0.30.6
+gradio==4.15.0
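
For reference, a minimal sketch of how the new Gradio endpoint could be called once the app is running. This assumes the gradio_client package (shipped alongside gradio 4.x) and the default "/predict" api_name that gr.Interface registers; the URL below is a placeholder matching the local launch configured above, not the deployed Space URL.

# Minimal usage sketch (assumptions: gradio_client installed, app running locally on port 8080).
from gradio_client import Client

# Placeholder URL matching iface.launch(server_name="0.0.0.0", server_port=8080) above;
# substitute the public Space URL when deployed.
client = Client("http://127.0.0.1:8080")

# gr.Interface exposes its function under "/predict" by default;
# positional arguments follow the inputs list: text, then max_length.
result = client.predict("Hello, Gemma!", 50, api_name="/predict")
print(result)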