File size: 2,536 Bytes
dbe5e76
 
 
624a54f
dbe5e76
 
 
 
624a54f
 
 
 
 
 
dbe5e76
 
624a54f
 
 
 
 
 
 
 
 
 
 
 
 
0bfba0d
 
624a54f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe5e76
 
 
624a54f
 
0bfba0d
624a54f
 
 
0bfba0d
 
624a54f
 
 
dbe5e76
 
 
624a54f
dbe5e76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import json
import random
import string
import requests
import gradio as gr

class GradioUI:
    """Gradio front end that asks a hosted text-generation model to write a story.

    The model is queried through the Hugging Face Inference API using the
    access token found in the ``HF_TOKEN`` environment variable.
    """

    # Sample titles offered by the "Get Random Title" button.
    TITLES = (
        'Pokemon training story',
        'The Sun of Shangri-La',
        'Man In The Future',
        'Friends',
        'Cyborg Of A Beast',
        'Man At The Graveyard',
        'Vampire Of The Land',
        'A software engineer who is looking for job',
        'A software engineer licensed to kill',
    )

    # Seconds to wait for the inference API; without a timeout a stalled
    # request would block a queue worker forever.
    REQUEST_TIMEOUT = 120

    def __init__(self):
        """Set up the CSS, the API token (may be ``None``) and the model id."""
        self.styles = """
        ... (styles omitted for brevity) ...
        """
        self.token = os.getenv('HF_TOKEN')
        self.model_id = 'meta-llama/Llama-2-70b-chat-hf'

    def random_title(self):
        """Return one of the sample titles, chosen at random."""
        return random.choice(self.TITLES)

    @staticmethod
    def _build_prompt(title, story):
        """Join the title and the story text into the prompt sent to the model."""
        return f"{title}. {story}"

    @staticmethod
    def _extract_generated_text(raw):
        """Decode a raw API response body and return its generated text.

        Args:
            raw: UTF-8 encoded JSON bytes; expected to be a list whose first
                item has a ``generated_text`` key.

        Raises:
            ValueError: if the body is not valid JSON or lacks that shape.
        """
        payload = json.loads(raw.decode("utf-8"))
        try:
            return payload[0]['generated_text']
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(f"Unexpected response payload: {payload!r}") from err

    def generate_text(self, title, story):
        """Ask the model to continue the story and return the generated text.

        Args:
            title: Story title; becomes the first sentence of the prompt.
            story: Story text written so far.

        Returns:
            The text generated by the model (the prompt is not echoed back,
            because ``return_full_text`` is disabled).

        Raises:
            ValueError: if the API answers with a status other than 200 or
                with a body of an unexpected shape.
        """
        prompt = self._build_prompt(title, story)
        url = f'https://api-inference.huggingface.co/models/{self.model_id}'
        headers = {'Content-type': 'application/json'}
        # Only authenticate when a token is configured; otherwise the header
        # would literally read "Bearer None".
        if self.token:
            headers['Authorization'] = f'Bearer {self.token}'
        data = {
            'inputs': prompt,
            'stream': False,
            'options': {
                'use_cache': False,
            },
            'parameters': {
                'max_new_tokens': 512,
                'do_sample': True,
                'return_full_text': False,
                'temperature': 1.0,
                'top_k': 50,
                'repetition_penalty': 1.2
            }
        }

        r = requests.post(
            url,
            headers=headers,
            data=json.dumps(data),
            timeout=self.REQUEST_TIMEOUT,
        )
        # Check the numeric status: the reason phrase is free text and is not
        # a reliable success indicator.
        if r.status_code != 200:
            raise ValueError(f"Response other than 200: {r.status_code} {r.reason}")
        return self._extract_generated_text(r.content)

    def _generate_chat(self, title, story):
        """Run ``generate_text`` and wrap the result as a chatbot history.

        The chatbot component expects a list of ``[user, bot]`` message pairs
        rather than a bare string.
        """
        reply = self.generate_text(title, story)
        return [[self._build_prompt(title, story), reply]]

    def generate_interface(self):
        """Build the Blocks layout, wire up the buttons and launch the app."""
        with gr.Blocks(css=self.styles) as demo:
            title = gr.Textbox(placeholder="Title", value=self.random_title())
            random_title_btn = gr.Button("Get Random Title")
            random_title_btn.click(fn=self.random_title, inputs=None, outputs=[title])
            editor = gr.Textbox(placeholder="Write your story here.", lines=32, max_lines=32, elem_classes=['no-label', 'small-big-textarea'])
            gen_btn = gr.Button("Generate Text")
            chatbot = gr.Chatbot([])

            # Route through the wrapper so the chatbot receives message pairs.
            gen_btn.click(fn=self._generate_chat, inputs=[title, editor], outputs=[chatbot])

            # More Gradio elements and logic can be added here.

        # NOTE(review): `concurrency_count` is a Gradio 3.x queue argument;
        # confirm the pinned Gradio version before upgrading.
        demo.queue(concurrency_count=5, max_size=256)
        demo.launch()

# Script entry point: create the UI wrapper, then build and launch the app.
# NOTE: there is no `__main__` guard, so this also runs when the module is imported.
ui = GradioUI()
ui.generate_interface()