Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
import requests
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
class TextGenerator:
    """Generate text through the Hugging Face serverless Inference API."""

    def __init__(self, token, model_id):
        # token: default HF API token used as the Bearer credential.
        # model_id: default model repository id on the Hub.
        self.token = token
        self.model_id = model_id

    def generate_text(self, prompt, hf_model=None, hf_token=None, parameters=None):
        """Send `prompt` to the Inference API and return the generated text.

        hf_model, hf_token and parameters fall back to the instance
        defaults when None.  Raises requests.HTTPError on a non-2xx
        response and ValueError when the API returns a payload without
        a generation (e.g. an error/"model loading" message).
        """
        if hf_token is None:
            hf_token = self.token
        if hf_model is None:
            hf_model = self.model_id
        if parameters is None:
            parameters = {'max_new_tokens': 512, 'do_sample': True, 'return_full_text': False, 'temperature': 1.0, 'top_k': 50, 'repetition_penalty': 1.2}
        url = f'https://api-inference.huggingface.co/models/{hf_model}'
        headers = {'Authorization': f'Bearer {hf_token}', 'Content-Type': 'application/json'}
        data = {'inputs': prompt, 'stream': False, 'options': {'use_cache': False}, 'parameters': parameters}
        # `json=` serializes the body and sets the content type; fail fast
        # on HTTP errors instead of crashing later with an opaque KeyError
        # when the payload is an error dict rather than a generation list.
        r = requests.post(url, headers=headers, json=data)
        r.raise_for_status()
        payload = r.json()
        if not isinstance(payload, list) or not payload:
            raise ValueError(f'Unexpected Inference API response: {payload!r}')
        return payload[0]['generated_text']
|
25 |
+
|
26 |
+
class GradioUI:
    """Gradio Blocks front-end wired to a TextGenerator backend."""

    def __init__(self, text_generator):
        # Backend used to produce completions for the UI.
        self.text_generator = text_generator
        self.styles = """...""" # Add CSS styles here
        # Candidate placeholder titles offered to the user.
        self.title_placeholders = ['Pokemon training story']

    def random_title(self):
        """Return a randomly chosen placeholder title."""
        return random.choice(self.title_placeholders)

    def generate_interface(self):
        """Build and launch the Gradio app (blocks until shut down)."""
        with gr.Blocks(css=self.styles) as demo:
            # BUG FIX: `title` and `random_title_btn` were referenced below
            # but never created, so this method raised NameError at startup
            # (the Space's "Runtime error").  Define the components first.
            title = gr.Textbox(label='Title', placeholder=self.random_title())
            random_title_btn = gr.Button('Random title')
            # Gradio UI setup goes here
            random_title_btn.click(fn=None, inputs=None, outputs=[title], _js=f"return ['{self.random_title()}'];")
            # Additional Gradio configurations go here...
        demo.queue(concurrency_count=5, max_size=256)
        demo.launch()
|
42 |
+
|
43 |
+
def main():
    """Wire the text generator to the UI and launch the Gradio app."""
    # HF_TOKEN must be set in the environment (e.g. as a Space secret);
    # os.getenv returns None when it is missing — TODO confirm the
    # Inference API token is configured before deploying.
    token = os.getenv('HF_TOKEN')
    model_id = 'meta-llama/Llama-2-70b-chat-hf'
    text_gen = TextGenerator(token, model_id)
    ui = GradioUI(text_gen)
    ui.generate_interface()


# Guard the entry point so importing this module does not launch a server.
if __name__ == '__main__':
    main()
|