Bradarr committed on
Commit
30b6044
·
1 Parent(s): b495cef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -1
app.py CHANGED
@@ -1,3 +1,75 @@
1
  import gradio as gr
 
 
 
2
 
3
- gr.Interface.load("models/EleutherAI/gpt-j-6B").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import requests
3
+ import json
4
+ import os
5
 
6
+ #Import Hugging Face's Transformers
7
+ from transformers import pipeline
8
+ # This is to log our outputs in a nicer format
9
+ from pprint import pprint
10
+
11
+ # from transformers import GPTJForCausalLM
12
+ # import torch
13
+
14
+ # model = GPTJForCausalLM.from_pretrained(
15
+ # "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
16
+ # )
17
+
# Module-level text-generation pipeline. NOTE(review): despite the filename's
# GPT-J references, this loads GPT-Neo 2.7B; the model is downloaded at import
# time, so startup is slow on first run.
generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
19
+
20
+ # from transformers import GPTJForCausalLM, AutoTokenizer
21
+ # import torch
22
+
23
+ # model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16, low_cpu_mem_usage=True)
24
+ # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
25
+
26
+ # prompt = (
27
+ # "In a shocking finding, scientists discovered a herd of unicorns living in a remote, "
28
+ # "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
29
+ # "researchers was the fact that the unicorns spoke perfect English."
30
+ # )
31
+
32
+ # input_ids = tokenizer(prompt, return_tensors="pt").input_ids
33
+
34
+ # gen_tokens = model.generate(
35
+ # input_ids,
36
+ # do_sample=True,
37
+ # temperature=0.9,
38
+ # max_length=100,
39
+ # )
40
+ # gen_text = tokenizer.batch_decode(gen_tokens)[0]
41
+
def run(prompt, max_len, temp):
    """Generate a completion for *prompt* with the module-level pipeline.

    Args:
        prompt: Seed text for generation.
        max_len: Maximum total token length of the sampled output.
        temp: Sampling temperature.

    Returns:
        A 2-tuple ``(generated_text, "")`` matching the two Gradio outputs
        (the generated text and an empty log/error field).
    """
    result = generator(
        prompt,
        do_sample=True,
        min_length=1,  # always allow at least one token
        max_length=max_len,
        temperature=temp,
    )
    generated = result[0]['generated_text']
    return (generated, "")
46
if __name__ == "__main__":
    # BUG FIX: the original referenced an undefined name `description`,
    # which raised NameError before the UI could build. Define it here.
    description = "Text generation demo. Enter a prompt, then adjust length and temperature."

    demo = gr.Blocks()
    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    label="Input",
                    value=" ",  # should be set to " " when plugged into a real API
                )
                tokens = gr.Slider(1, 250, value=50, step=1, label="Tokens to generate")
                temp = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")

                with gr.Row():
                    submit = gr.Button("Submit")
            with gr.Column():
                text_error = gr.Markdown(label="Log information")
                text_out = gr.Textbox(label="Output")
        with gr.Row():
            # BUG FIX: the original passed undefined `temperature`; the
            # slider variable is named `temp`.
            submit.click(
                run,
                inputs=[text, tokens, temp],
                outputs=[text_out, text_error],
            )

    demo.launch()
74
+
75
+ #gr.Interface.load("models/EleutherAI/gpt-j-6B").launch()