zee2221 committed on
Commit 7eeee22 · 1 Parent(s): 98679f3
Files changed (1)
app.py +151 -1
app.py CHANGED
@@ -1,3 +1,153 @@
  import gradio as gr

- gr.Interface.load("models/bigscience/bloom").launch()
+ import requests
+ import json
+ import os
+ from screenshot import (
+     before_prompt,
+     prompt_to_generation,
+     after_generation,
+     js_save,
+     js_load_script,
+ )
+ from spaces_info import description, examples, initial_prompt_value
+
+ API_URL = os.getenv("API_URL")
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+
+
+ def query(payload):
+     print(payload)
+     response = requests.request("POST", API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
+     print(response)
+     return json.loads(response.content.decode("utf-8"))
+
+
+ def inference(input_sentence, max_length, sample_or_greedy, seed=42):
+     if sample_or_greedy == "Sample":
+         parameters = {
+             "max_new_tokens": max_length,
+             "top_p": 0.9,
+             "do_sample": True,
+             "seed": seed,
+             "early_stopping": False,
+             "length_penalty": 0.0,
+             "eos_token_id": None,
+         }
+     else:
+         parameters = {
+             "max_new_tokens": max_length,
+             "do_sample": False,
+             "seed": seed,
+             "early_stopping": False,
+             "length_penalty": 0.0,
+             "eos_token_id": None,
+         }
+
+     payload = {"inputs": input_sentence, "parameters": parameters, "options": {"use_cache": False}}
+
+     data = query(payload)
+
+     if "error" in data:
+         return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")
+
+     generation = data[0]["generated_text"].split(input_sentence, 1)[1]
+     return (
+         before_prompt
+         + input_sentence
+         + prompt_to_generation
+         + generation
+         + after_generation,
+         data[0]["generated_text"],
+         "",
+     )
+
+
+ if __name__ == "__main__":
+     demo = gr.Blocks()
+     with demo:
+         with gr.Row():
+             gr.Markdown(value=description)
+         with gr.Row():
+             with gr.Column():
+                 text = gr.Textbox(
+                     label="Input",
+                     value=" ",  # should be set to " " when plugged into a real API
+                 )
+                 tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
+                 sampling = gr.Radio(
+                     ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
+                 )
+                 sampling2 = gr.Radio(
+                     ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
+                     value="Sample 1",
+                     label="Sample other generations (only works in 'Sample' mode)",
+                     type="index",
+                 )
+
+                 with gr.Row():
+                     submit = gr.Button("Submit")
+                     load_image = gr.Button("Generate Image")
+             with gr.Column():
+                 text_error = gr.Markdown(label="Log information")
+                 text_out = gr.Textbox(label="Output")
+                 display_out = gr.HTML(label="Image")
+                 display_out.set_event_trigger(
+                     "load",
+                     fn=None,
+                     inputs=None,
+                     outputs=None,
+                     no_target=True,
+                     js=js_load_script,
+                 )
+         with gr.Row():
+             gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])
+
+         submit.click(
+             inference,
+             inputs=[text, tokens, sampling, sampling2],
+             outputs=[display_out, text_out, text_error],
+         )
+
+         load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)
+
+     demo.launch()
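
For readers of this diff, a minimal sketch (not part of the commit) of the request/response exchange that the new query() and inference() helpers assume. API_URL and HF_API_TOKEN are read from the environment exactly as in app.py; the prompt string and the response contents shown in the comments are illustrative placeholders, not output from the commit.

import json
import os

import requests

API_URL = os.getenv("API_URL")            # text-generation endpoint, as in app.py
HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # bearer token, as in app.py

# Same payload shape that inference() builds before calling query().
payload = {
    "inputs": "Hello, my name is",  # illustrative prompt
    "parameters": {"max_new_tokens": 32, "top_p": 0.9, "do_sample": True, "seed": 42},
    "options": {"use_cache": False},
}

response = requests.post(API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
data = json.loads(response.content.decode("utf-8"))

# On success the code above expects a list of generations, e.g.
# [{"generated_text": "Hello, my name is ..."}]; inference() reads
# data[0]["generated_text"] and splits the prompt off the front.
# On failure it expects {"error": "..."}, which inference() surfaces
# as red text in the Gradio UI.
if isinstance(data, dict) and "error" in data:
    print("API error:", data["error"])
else:
    print(data[0]["generated_text"])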