vykanand committed
Commit ef12370 · Parent: 8785344

Remove Gradio, switch to FastAPI implementation

Files changed (2):
  1. README.md +50 -3
  2. app.py +7 -21
README.md CHANGED
@@ -3,9 +3,56 @@ title: LLaMA 7B Server
 emoji: 🤖
 colorFrom: blue
 colorTo: purple
-sdk: gradio
-sdk_version: "4.17.0"
-app_file: app.py
+# LLaMA 7B Server
+
+A FastAPI-based server for interacting with the LLaMA 7B model.
+
+## Features
+
+- [x] Text generation
+- [x] Model parameters configuration
+- [x] REST API interface
+
+## API Usage
+
+Make a POST request to `/generate` with the following JSON body:
+
+```json
+{
+    "prompt": "your prompt here",
+    "max_length": 2048,
+    "num_beams": 3,
+    "early_stopping": true,
+    "no_repeat_ngram_size": 3
+}
+```
+
+Example using curl:
+
+```bash
+curl -X POST http://localhost:7860/generate \
+     -H "Content-Type: application/json" \
+     -d '{"prompt": "Hello, how are you?"}'
+```
+
+Example using Python:
+
+```python
+import requests
+
+url = "http://localhost:7860/generate"
+data = {
+    "prompt": "Hello, how are you?",
+    "max_length": 2048,
+    "num_beams": 3,
+    "early_stopping": True,
+    "no_repeat_ngram_size": 3
+}
+
+response = requests.post(url, json=data)
+result = response.json()
+print(result["generated_text"])  # this will contain your generated text
+```
 pinned: false
 ---
 
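The README's examples assume the server is reachable at http://localhost:7860. For repeated calls, a slightly more defensive client is handy; this is a minimal sketch (the `generate` helper and the timeout value are illustrative, not part of the commit):

```python
import requests

def generate(prompt: str, url: str = "http://localhost:7860/generate", **params) -> str:
    """POST a prompt to the /generate endpoint and return the generated text."""
    payload = {"prompt": prompt, **params}  # extra params: max_length, num_beams, ...
    # Beam-search generation on a 7B model can be slow; allow a generous timeout.
    response = requests.post(url, json=payload, timeout=600)
    response.raise_for_status()  # surface 4xx/5xx errors instead of failing on a missing key
    return response.json()["generated_text"]

if __name__ == "__main__":
    print(generate("Hello, how are you?", max_length=256))
```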
 
app.py CHANGED
@@ -18,33 +18,19 @@ class GenerationRequest(BaseModel):
     early_stopping: bool = True
     no_repeat_ngram_size: int = 3
 
-def generate(prompt: str, max_length: int = 2048, num_beams: int = 3, early_stopping: bool = True, no_repeat_ngram_size: int = 3):
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+@app.post("/generate")
+async def generate_text(request: GenerationRequest):
+    inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
     outputs = model.generate(
         **inputs,
-        max_length=max_length,
-        num_beams=num_beams,
-        early_stopping=early_stopping,
-        no_repeat_ngram_size=no_repeat_ngram_size,
+        max_length=request.max_length,
+        num_beams=request.num_beams,
+        early_stopping=request.early_stopping,
+        no_repeat_ngram_size=request.no_repeat_ngram_size,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
     )
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return output_text
-
-iface = gr.Interface(
-    fn=generate,
-    inputs=gr.Textbox(lines=10, label="Input Prompt"),
-    outputs=gr.Textbox(label="Generated Output"),
-    title="LLaMA 7B Server",
-    description="A web interface for interacting with the LLaMA 7B model.",
-    allow_flagging="never",
-    api_open=True
-)
-
-@app.post("/generate")
-async def generate_text(request: GenerationRequest):
-    return {"generated_text": generate(**request.dict())}
     return {"generated_text": output_text}
 
 if __name__ == "__main__":
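The hunk above covers only lines 18–35 of app.py. For orientation, here is a minimal sketch of what the whole post-commit file plausibly looks like; the imports, checkpoint id, device selection, pad-token handling, and uvicorn invocation are assumptions that do not appear in this diff:

```python
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="LLaMA 7B Server")

MODEL_NAME = "huggyllama/llama-7b"  # assumption: the exact checkpoint id is not in the diff
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers ship without a pad token

class GenerationRequest(BaseModel):
    prompt: str
    max_length: int = 2048
    num_beams: int = 3
    early_stopping: bool = True
    no_repeat_ngram_size: int = 3

@app.post("/generate")
async def generate_text(request: GenerationRequest):
    inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_length=request.max_length,
        num_beams=request.num_beams,
        early_stopping=request.early_stopping,
        no_repeat_ngram_size=request.no_repeat_ngram_size,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": output_text}

if __name__ == "__main__":
    # Hugging Face Spaces routes traffic to port 7860 by default.
    uvicorn.run(app, host="0.0.0.0", port=7860)
```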