laserbeam2045 committed
Commit a751c84 · 1 Parent(s): af0df21
Files changed (2):
  1. app.py +28 -19
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,47 +1,56 @@
 import os
 import torch
-from fastapi import FastAPI
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from pydantic import BaseModel
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import logging
 
-# Logging setup
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI()
-
-# Load model
 model_name = "google/gemma-2-2b-it"
 try:
     logger.info(f"Loading model: {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))
+    use_gpu = torch.cuda.is_available()
+    logger.info(f"GPU available: {use_gpu}")
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True
+    ) if use_gpu else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         device_map="auto",
         token=os.getenv("HF_TOKEN"),
         low_cpu_mem_usage=True,
-        load_in_4bit=True
+        quantization_config=quantization_config
     )
     logger.info("Model loaded successfully")
 except Exception as e:
     logger.error(f"Model load error: {e}")
     raise
 
-class TextInput(BaseModel):
-    text: str
-    max_length: int = 50
-
-@app.post("/generate")
-async def generate_text(input: TextInput):
+def generate_text(text, max_length=50):
     try:
-        logger.info(f"Generating text for input: {input.text}")
-        inputs = tokenizer(input.text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-        outputs = model.generate(**inputs, max_length=input.max_length)
+        logger.info(f"Generating text for input: {text}")
+        inputs = tokenizer(text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+        outputs = model.generate(**inputs, max_length=max_length)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logger.info(f"Generated text: {result}")
-        return {"generated_text": result}
+        return result
     except Exception as e:
         logger.error(f"Generation error: {e}")
-        return {"error": str(e)}
+        return f"Error: {str(e)}"
+
+iface = gr.Interface(
+    fn=generate_text,
+    inputs=[gr.Textbox(label="Input Text"), gr.Slider(10, 100, value=50, label="Max Length")],
+    outputs=gr.Textbox(label="Generated Text"),
+    title="Gemma 2 API"
+)
+
+if __name__ == "__main__":
+    logger.info("Launching Gradio interface")
+    iface.launch(server_name="0.0.0.0", server_port=8080)
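
This commit removes the FastAPI POST /generate endpoint, so any caller that used it over HTTP now has to go through the Gradio app instead. A minimal client sketch, assuming the app is reachable at http://localhost:8080 and that the gradio_client package (pulled in by recent gradio releases) is available; the prompt text below is illustrative, not part of this commit:

from gradio_client import Client

# Connect to the Gradio app launched by app.py (the URL is an assumption).
client = Client("http://localhost:8080")

# Positional args match the Interface inputs: text, then max_length.
result = client.predict(
    "Hello, Gemma!",      # Input Text (illustrative prompt)
    50,                   # Max Length slider value
    api_name="/predict",  # default endpoint name for a gr.Interface
)
print(result)
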
requirements.txt CHANGED
@@ -6,3 +6,4 @@ bitsandbytes==0.42.0
 accelerate==0.26.1
 fastapi==0.115.0
 uvicorn==0.30.6
+gradio==4.15.0