Update Gradio configuration to expose API endpoint
app.py CHANGED
```diff
@@ -18,19 +18,33 @@ class GenerationRequest(BaseModel):
     early_stopping: bool = True
     no_repeat_ngram_size: int = 3
 
-@app.post("/generate")
-async def generate_text(request: GenerationRequest):
-    inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
+def generate(prompt: str, max_length: int = 2048, num_beams: int = 3, early_stopping: bool = True, no_repeat_ngram_size: int = 3):
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
     outputs = model.generate(
         **inputs,
-        max_length=request.max_length,
-        num_beams=request.num_beams,
-        early_stopping=request.early_stopping,
-        no_repeat_ngram_size=request.no_repeat_ngram_size,
+        max_length=max_length,
+        num_beams=num_beams,
+        early_stopping=early_stopping,
+        no_repeat_ngram_size=no_repeat_ngram_size,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
     )
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return output_text
+
+iface = gr.Interface(
+    fn=generate,
+    inputs=gr.Textbox(lines=10, label="Input Prompt"),
+    outputs=gr.Textbox(label="Generated Output"),
+    title="LLaMA 7B Server",
+    description="A web interface for interacting with the LLaMA 7B model.",
+    allow_flagging="never",
+    api_open=True
+)
+
+@app.post("/generate")
+async def generate_text(request: GenerationRequest):
+    return {"generated_text": generate(**request.dict())}
     return {"generated_text": output_text}
 
 if __name__ == "__main__":
```
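As committed, the generation logic now lives in a plain `generate` function that the Gradio UI and the FastAPI route both call. A minimal client sketch for the exposed REST endpoint, assuming a hypothetical Space host `https://<user>-llama-7b-server.hf.space` and that `GenerationRequest` declares `prompt`, `max_length`, and `num_beams` alongside the two fields visible in the hunk:

```python
import requests

# Hypothetical host; substitute the real Space URL.
BASE_URL = "https://<user>-llama-7b-server.hf.space"

# Field names mirror the GenerationRequest model; prompt, max_length,
# and num_beams are assumed to be declared above the visible hunk.
payload = {
    "prompt": "Explain beam search in one paragraph.",
    "max_length": 512,
    "num_beams": 3,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
}

resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["generated_text"])
```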
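The Gradio interface exposes the same function through its own API. A sketch using the `gradio_client` package, assuming a hypothetical Space id `user/llama-7b-server` and the default `/predict` endpoint name that `gr.Interface` assigns to its function:

```python
from gradio_client import Client

# Hypothetical Space id; substitute the real one.
client = Client("user/llama-7b-server")

# Only the single Textbox feeds the Gradio signature; the remaining
# generate() parameters fall back to their Python defaults.
result = client.predict(
    "Explain beam search in one paragraph.",  # Input Prompt
    api_name="/predict",
)
print(result)
```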