sonyps1928 committed
Commit ad32177 · Parent: 66a26f6

Add application file

Files changed (2):
  1. app.py +192 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ import gradio as gr
+ import os
+ from transformers import (
+     GPT2LMHeadModel, GPT2Tokenizer,
+     T5ForConditionalGeneration, T5Tokenizer,
+     AutoTokenizer, AutoModelForCausalLM
+ )
+ import torch
+
+ # Configuration for multiple models; more can be added by extending the MODEL_CONFIGS dict
+ MODEL_CONFIGS = {
+     "gpt2": {
+         "type": "causal",
+         "model_class": GPT2LMHeadModel,
+         "tokenizer_class": GPT2Tokenizer,
+         "description": "Original GPT-2, good for creative writing",
+         "size": "117M"
+     },
+     "distilgpt2": {
+         "type": "causal",
+         "model_class": AutoModelForCausalLM,
+         "tokenizer_class": AutoTokenizer,
+         "description": "Smaller, faster GPT-2",
+         "size": "82M"
+     },
+     "google/flan-t5-small": {
+         "type": "seq2seq",
+         "model_class": T5ForConditionalGeneration,
+         "tokenizer_class": T5Tokenizer,
+         "description": "Instruction-following T5 model",
+         "size": "80M"
+     },
+     "microsoft/DialoGPT-small": {
+         "type": "causal",
+         "model_class": AutoModelForCausalLM,
+         "tokenizer_class": AutoTokenizer,
+         "description": "Conversational AI model",
+         "size": "117M"
+     }
+ }
+
+ # Environment variables for optional authentication and private model access
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ API_KEY = os.getenv("API_KEY")
+ ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
+
+ # Global state for caching the currently loaded model and tokenizer
+ loaded_model_name = None
+ model = None
+ tokenizer = None
+
+ def load_model_and_tokenizer(model_name):
+     global loaded_model_name, model, tokenizer
+     if model_name == loaded_model_name and model is not None and tokenizer is not None:
+         return model, tokenizer
+
+     config = MODEL_CONFIGS[model_name]
+     if HF_TOKEN:
+         tokenizer = config["tokenizer_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
+         model = config["model_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
+     else:
+         tokenizer = config["tokenizer_class"].from_pretrained(model_name)
+         model = config["model_class"].from_pretrained(model_name)
+
+     # Set a pad token for causal models if missing (needed for generation padding)
+     if config["type"] == "causal" and tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     loaded_model_name = model_name
+     return model, tokenizer
+
+ def authenticate_api_key(key):
+     if API_KEY and key != API_KEY:
+         return False
+     return True
+
+ def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
+     if API_KEY and not authenticate_api_key(api_key):
+         return "Error: Invalid API key"
+
+     try:
+         config = MODEL_CONFIGS[model_name]
+         model, tokenizer = load_model_and_tokenizer(model_name)
+
+         if config["type"] == "causal":
+             inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+             with torch.no_grad():
+                 outputs = model.generate(
+                     inputs,
+                     max_length=min(max_length + inputs.shape[1], 512),
+                     temperature=temperature,
+                     top_p=top_p,
+                     top_k=top_k,
+                     do_sample=True,
+                     pad_token_id=tokenizer.pad_token_id,
+                     num_return_sequences=1
+                 )
+             generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             # Return only the generated continuation (strip the original prompt)
+             return generated_text[len(prompt):].strip()
+
+         elif config["type"] == "seq2seq":
+             # Add a task prefix for instruction-tuned seq2seq models like flan-t5
+             task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
+             inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
+             with torch.no_grad():
+                 outputs = model.generate(
+                     **inputs,
+                     max_length=max_length,
+                     temperature=temperature,
+                     top_p=top_p,
+                     top_k=top_k,
+                     do_sample=True,
+                     num_return_sequences=1
+                 )
+             generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             return generated_text.strip()
+
+     except Exception as e:
+         return f"Error generating text: {str(e)}"
+
+ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
+     gr.Markdown("# Multi-Model Text Generation Server")
+     gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")
+
+     with gr.Row():
+         with gr.Column():
+             model_selector = gr.Dropdown(
+                 label="Model",
+                 choices=list(MODEL_CONFIGS.keys()),
+                 value="gpt2",
+                 interactive=True
+             )
+             prompt_input = gr.Textbox(
+                 label="Text Prompt",
+                 placeholder="Enter the text prompt here...",
+                 lines=4
+             )
+             max_length_slider = gr.Slider(
+                 minimum=10, maximum=200, value=100, step=10,
+                 label="Max Generation Length"
+             )
+             temperature_slider = gr.Slider(
+                 minimum=0.1, maximum=2.0, value=0.7, step=0.1,
+                 label="Temperature"
+             )
+             top_p_slider = gr.Slider(
+                 minimum=0.1, maximum=1.0, value=0.9, step=0.05,
+                 label="Top-p (nucleus sampling)"
+             )
+             top_k_slider = gr.Slider(
+                 minimum=1, maximum=100, value=50, step=1,
+                 label="Top-k sampling"
+             )
+             if API_KEY:
+                 api_key_input = gr.Textbox(
+                     label="API Key",
+                     type="password",
+                     placeholder="Enter API Key"
+                 )
+             else:
+                 api_key_input = gr.Textbox(value="", visible=False)
+
+             generate_btn = gr.Button("Generate Text", variant="primary")
+
+         with gr.Column():
+             output_textbox = gr.Textbox(
+                 label="Generated Text",
+                 lines=10,
+                 placeholder="Generated text will appear here..."
+             )
+
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt_input, model_selector, max_length_slider, temperature_slider, top_p_slider, top_k_slider, api_key_input],
+         outputs=output_textbox
+     )
+
+     gr.Examples(
+         examples=[
+             ["Once upon a time in a distant galaxy,"],
+             ["The future of artificial intelligence is"],
+             ["In the heart of the ancient forest,"],
+             ["The detective walked into the room and noticed"],
+         ],
+         inputs=prompt_input
+     )
+
+ auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
+
+ if __name__ == "__main__":
+     demo.launch(auth=auth_config)
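
For a quick sanity check without starting the UI, generate_text can be called directly. The snippet below is an illustrative sketch, not part of the commit: it assumes app.py is importable from the working directory, and the first call will download the selected model's weights from the Hub. Importing the module builds the Gradio interface but does not launch the server, since demo.launch() is guarded by the __main__ check.

# Hypothetical smoke test for the generation path (not part of this commit).
from app import generate_text

continuation = generate_text(
    prompt="Once upon a time in a distant galaxy,",
    model_name="distilgpt2",  # smallest causal model in MODEL_CONFIGS
    max_length=50,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
)
print(continuation)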
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio>=3.50.0
+ transformers>=4.30.0
+ torch>=2.0.0
+ tokenizers>=0.13.0
+ sentencepiece>=0.1.97  # required by the slow T5Tokenizer used for google/flan-t5-small
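
With the dependencies installed and the server started via python app.py, the generate button's handler can also be driven programmatically. The sketch below is a hedged example, not part of the commit: it assumes the separate gradio_client package is installed, the app is listening on Gradio's default local port, and ADMIN_PASSWORD is unset so no login is required. Because the commit does not assign the click event an api_name, the endpoint is addressed by fn_index.

# Hypothetical client-side call to the running app (not part of this commit).
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # default local URL printed by demo.launch()
result = client.predict(
    "The future of artificial intelligence is",  # prompt
    "gpt2",  # model_name
    100,     # max_length
    0.7,     # temperature
    0.9,     # top_p
    50,      # top_k
    "",      # api_key (only enforced when API_KEY is set server-side)
    fn_index=0,  # the generate button's click handler is the only registered event
)
print(result)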