vykanand committed on
Commit
8785344
·
1 Parent(s): 8467aa9

Update Gradio configuration to expose API endpoint

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -18,19 +18,33 @@ class GenerationRequest(BaseModel):
18
  early_stopping: bool = True
19
  no_repeat_ngram_size: int = 3
20
 
21
- @app.post("/generate")
22
- async def generate_text(request: GenerationRequest):
23
- inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
24
  outputs = model.generate(
25
  **inputs,
26
- max_length=request.max_length,
27
- num_beams=request.num_beams,
28
- early_stopping=request.early_stopping,
29
- no_repeat_ngram_size=request.no_repeat_ngram_size,
30
  eos_token_id=tokenizer.eos_token_id,
31
  pad_token_id=tokenizer.pad_token_id,
32
  )
33
  output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return {"generated_text": output_text}
35
 
36
  if __name__ == "__main__":
 
18
  early_stopping: bool = True
19
  no_repeat_ngram_size: int = 3
20
 
21
+ def generate(prompt: str, max_length: int = 2048, num_beams: int = 3, early_stopping: bool = True, no_repeat_ngram_size: int = 3):
22
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
23
  outputs = model.generate(
24
  **inputs,
25
+ max_length=max_length,
26
+ num_beams=num_beams,
27
+ early_stopping=early_stopping,
28
+ no_repeat_ngram_size=no_repeat_ngram_size,
29
  eos_token_id=tokenizer.eos_token_id,
30
  pad_token_id=tokenizer.pad_token_id,
31
  )
32
  output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
33
+ return output_text
34
+
35
+ iface = gr.Interface(
36
+ fn=generate,
37
+ inputs=gr.Textbox(lines=10, label="Input Prompt"),
38
+ outputs=gr.Textbox(label="Generated Output"),
39
+ title="LLaMA 7B Server",
40
+ description="A web interface for interacting with the LLaMA 7B model.",
41
+ allow_flagging="never",
42
+ api_open=True
43
+ )
44
+
45
+ @app.post("/generate")
46
+ async def generate_text(request: GenerationRequest):
47
+ return {"generated_text": generate(**request.dict())}
48
  return {"generated_text": output_text}
49
 
50
  if __name__ == "__main__":