laserbeam2045 committed on
Commit 115a37b · 1 Parent(s): d1fd8de
Files changed (1)
  1. app.py +13 -24
app.py CHANGED
@@ -1,7 +1,8 @@
 import os
 import torch
-import gradio as gr
+from fastapi import FastAPI
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from pydantic import BaseModel
 import logging
 
 logging.basicConfig(level=logging.INFO)
@@ -16,8 +17,6 @@ model = None
 try:
     logger.info(f"Loading model: {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))
-    use_gpu = torch.cuda.is_available()
-    logger.info(f"GPU available: {use_gpu}")
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.float16,  # reduce memory usage
@@ -30,29 +29,19 @@ except Exception as e:
     logger.error(f"Model load error: {e}")
     raise
 
-def generate_text(text, max_length=50):
+class TextInput(BaseModel):
+    text: str
+    max_length: int = 50
+
+@app.post("/generate")
+async def generate_text(input: TextInput):
     try:
-        logger.info(f"Generating text for input: {text}")
-        inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to("cpu")
-        outputs = model.generate(**inputs, max_length=max_length)
+        logger.info(f"Generating text for input: {input.text}")
+        inputs = tokenizer(input.text, return_tensors="pt", max_length=512, truncation=True).to("cpu")
+        outputs = model.generate(**inputs, max_length=input.max_length)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logger.info(f"Generated text: {result}")
-        return result
+        return {"generated_text": result}
     except Exception as e:
         logger.error(f"Generation error: {e}")
-        return f"Error: {str(e)}"
-
-iface = gr.Interface(
-    fn=generate_text,
-    inputs=[gr.Textbox(label="Input Text"), gr.Slider(10, 100, value=50, label="Max Length")],
-    outputs=gr.Textbox(label="Generated Text"),
-    title="Gemma 2 API"
-)
-
-if __name__ == "__main__":
-    try:
-        logger.info("Launching Gradio interface")
-        iface.launch(server_name="0.0.0.0", server_port=8080)
-    except Exception as e:
-        logger.error(f"Gradio launch error: {e}")
-        raise
+        return {"error": str(e)}
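Note: the unchanged hunk between the import block and the model-load block is elided above, so the diff never shows where the `app` object that `@app.post("/generate")` decorates is created, nor how the server is launched. A minimal sketch of that scaffolding, assuming a standard FastAPI/uvicorn setup and reusing the 0.0.0.0:8080 binding from the removed Gradio launch (both are assumptions, not part of this commit):

# Hypothetical scaffolding, assumed to live in the elided lines of app.py.
import uvicorn
from fastapi import FastAPI

app = FastAPI(title="Gemma 2 API")  # the object @app.post("/generate") decorates

# ... tokenizer/model loading and the /generate route as shown in the diff ...

if __name__ == "__main__":
    # Port carried over from the removed iface.launch(...) call; the commit
    # does not show the actual launch code.
    uvicorn.run(app, host="0.0.0.0", port=8080)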
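With the server running, the Gradio UI is replaced by a plain JSON POST against /generate. A quick client-side sketch using requests (the base URL is an assumption carried over from the removed launch code):

import requests

# Field names match the TextInput model: text is required, max_length defaults to 50.
resp = requests.post(
    "http://localhost:8080/generate",
    json={"text": "Hello, Gemma!", "max_length": 50},
)
print(resp.json())  # {"generated_text": "..."} on success, {"error": "..."} on failure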