laserbeam2045 committed
Commit 861971b · Parent: 468f1f8
Files changed (2)
  1. app.py +15 -19
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,12 +1,15 @@
 import os
 import torch
-import gradio as gr
+from fastapi import FastAPI
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from pydantic import BaseModel
 import logging

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+app = FastAPI()
+
 model_name = "google/gemma-2-2b-it"
 try:
     logger.info(f"Loading model: {model_name}")
@@ -32,26 +35,19 @@ except Exception as e:
     logger.error(f"Model load error: {e}")
     raise

-def generate_text(text, max_length=50):
+class TextInput(BaseModel):
+    text: str
+    max_length: int = 50
+
+@app.post("/generate")
+async def generate_text(input: TextInput):
     try:
-        logger.info(f"Generating text for input: {text}")
-        inputs = tokenizer(text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-        outputs = model.generate(**inputs, max_length=max_length)
+        logger.info(f"Generating text for input: {input.text}")
+        inputs = tokenizer(input.text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+        outputs = model.generate(**inputs, max_length=input.max_length)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logger.info(f"Generated text: {result}")
-        return result
+        return {"generated_text": result}
     except Exception as e:
         logger.error(f"Generation error: {e}")
-        return f"Error: {str(e)}"
-
-iface = gr.Interface(
-    fn=generate_text,
-    inputs=[gr.Textbox(label="Input Text"), gr.Slider(10, 100, value=50, label="Max Length")],
-    outputs=gr.Textbox(label="Generated Text"),
-    title="Gemma 2 API"
-)
-
-if __name__ == "__main__":
-    logger.info("Launching Gradio interface")
-    iface.launch(server_name="0.0.0.0", server_port=8080, share=True)
-    logger.info("end")
+        return {"error": str(e)}
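The context between the two hunks (the model- and tokenizer-loading body inside the `try:` block) is collapsed by the diff viewer. Below is a minimal sketch of what a quantized load with `BitsAndBytesConfig` typically looks like; every setting here is an assumption, since the actual values are not visible in this diff:

```python
# Hypothetical reconstruction of the elided load block; the real
# quantization settings in this commit are not shown in the diff.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # assumed: 4-bit quantization
    bnb_4bit_quant_type="nf4",              # assumed quant type
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed compute dtype
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",  # relies on accelerate, which requirements.txt pins
)
```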
 
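After this change the Space exposes a JSON API instead of a Gradio UI. A quick way to exercise the new endpoint using only the standard library, assuming the server is reachable on localhost:8080 (the port is carried over from the removed `iface.launch` call, not something this commit sets):

```python
import json
from urllib.request import Request, urlopen

# Assumed base URL: the commit removes the old launch call and does not
# show how (or on which port) the FastAPI app is served.
payload = json.dumps({"text": "Hello, Gemma!", "max_length": 50}).encode("utf-8")
req = Request(
    "http://localhost:8080/generate",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.load(resp))  # {"generated_text": "..."} or {"error": "..."}
```

Note that `max_length` in `model.generate` counts prompt tokens as well as newly generated ones, so small values can cut the completion short.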
requirements.txt CHANGED
@@ -6,4 +6,3 @@ bitsandbytes==0.42.0
 accelerate==0.26.1
 fastapi==0.115.0
 uvicorn==0.30.6
-gradio==4.15.0
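With `gradio` dropped, nothing in app.py starts a server anymore; presumably the Space's run command invokes uvicorn directly (e.g. `uvicorn app:app`). A hedged sketch of an equivalent in-file launch block, should one be wanted:

```python
# Hypothetical launch block; the committed app.py no longer contains one.
import uvicorn

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)  # port assumed from the removed Gradio launch
```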