vykanand committed
Commit 562f120 · Parent: a431f91

Update to use Gradio instead of FastAPI

Files changed (3)
  1. app.py +17 -11
  2. requirements.txt +1 -0
  3. start.sh +1 -0
app.py CHANGED
@@ -1,5 +1,4 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
+import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
@@ -9,14 +8,8 @@ model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5p-220m")
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
-app = FastAPI()
-
-class Input(BaseModel):
-    prompt: str
-
-@app.post("/generate")
-async def generate(input: Input):
-    inputs = tokenizer(input.prompt, return_tensors="pt").to(device)
+def generate(prompt):
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
     outputs = model.generate(
         **inputs,
         max_length=2048,
@@ -27,4 +20,17 @@ async def generate(input: Input):
         pad_token_id=tokenizer.pad_token_id,
     )
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"output": output_text}
+    return output_text
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=generate,
+    inputs=gr.Textbox(lines=10, label="Input Prompt"),
+    outputs=gr.Textbox(label="Generated Output"),
+    title="LLaMA 7B Server",
+    description="A web interface for interacting with the LLaMA 7B model."
+)
+
+# Launch the interface
+if __name__ == "__main__":
+    iface.launch(server_name="0.0.0.0", server_port=7860)
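With the FastAPI route gone, clients talk to Gradio's own API instead of a hand-rolled POST to /generate. A minimal client-side sketch, not part of this commit, assuming the app above is running locally on port 7860 and that the interface exposes Gradio's default /predict endpoint (note the title says "LLaMA 7B Server", but the hunk context shows the checkpoint is actually Salesforce/codet5p-220m, a code model, hence the code-completion prompt below):

# Hypothetical usage sketch: query the Gradio app defined in app.py above.
# Assumes `pip install gradio_client` and that app.py is already running on
# localhost:7860 (the host/port passed to iface.launch).
from gradio_client import Client

client = Client("http://localhost:7860/")

# The single Textbox input maps to generate(prompt); gr.Interface registers
# its function under the default endpoint name "/predict".
result = client.predict(
    "def fibonacci(n):",  # example prompt (hypothetical)
    api_name="/predict",
)
print(result)  # the decoded text returned by generate()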
requirements.txt CHANGED
@@ -2,3 +2,4 @@ fastapi
 uvicorn[standard]
 transformers
 torch
+gradio>=4.17.0
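Note: fastapi and uvicorn[standard] remain listed even though app.py no longer imports them; presumably they could be dropped in a follow-up commit.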
start.sh CHANGED
@@ -1,2 +1,3 @@
 #!/bin/bash
+python app.py
 uvicorn app:app --host 0.0.0.0 --port 7860
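As committed, the script still ends with the uvicorn line, but `python app.py` blocks on iface.launch(), so `uvicorn app:app` is never reached; app.py no longer defines the `app` object it refers to in any case. In practice the container now serves only the Gradio interface on port 7860.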