sonyps1928 commited on
Commit
adb694f
Β·
1 Parent(s): 2ce4afd

update app6

Browse files
Files changed (1) hide show
  1. app.py +44 -53
app.py CHANGED
@@ -3,18 +3,25 @@ import os
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
 
6
- # Set page config
 
 
7
  st.set_page_config(
8
  page_title="GPT-2 Text Generator",
9
  page_icon="πŸ€–",
10
  layout="wide"
11
  )
12
 
 
13
  # Load environment variables
 
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
- API_KEY = os.getenv("API_KEY")
16
  ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
17
 
 
 
 
18
  @st.cache_resource
19
  def load_model():
20
  """Load and cache the GPT-2 model"""
@@ -28,19 +35,20 @@ def load_model():
28
  st.error(f"Error loading model: {e}")
29
  return None, None
30
 
 
 
 
31
  def generate_text(prompt, max_length, temperature, tokenizer, model):
32
  """Generate text using GPT-2"""
33
  if not prompt:
34
  return "Please enter a prompt"
35
-
36
  if len(prompt) > 500:
37
  return "Prompt too long (max 500 characters)"
38
-
39
  try:
40
- # Encode the prompt
41
  inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)
42
-
43
- # Generate text
44
  with torch.no_grad():
45
  outputs = model.generate(
46
  inputs,
@@ -51,114 +59,97 @@ def generate_text(prompt, max_length, temperature, tokenizer, model):
51
  eos_token_id=tokenizer.eos_token_id,
52
  no_repeat_ngram_size=2
53
  )
54
-
55
- # Decode the output
56
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
  new_text = generated_text[len(prompt):].strip()
58
-
59
  return new_text if new_text else "No text generated. Try a different prompt."
60
-
61
  except Exception as e:
62
  return f"Error generating text: {str(e)}"
63
 
 
 
 
64
  def check_auth():
65
  """Handle authentication"""
66
  if ADMIN_PASSWORD:
67
  if "authenticated" not in st.session_state:
68
  st.session_state.authenticated = False
69
-
70
  if not st.session_state.authenticated:
71
  st.title("πŸ”’ Authentication Required")
72
  password = st.text_input("Enter admin password:", type="password")
73
  if st.button("Login"):
74
  if password == ADMIN_PASSWORD:
75
  st.session_state.authenticated = True
76
- st.rerun()
77
  else:
78
  st.error("Invalid password")
79
  return False
80
  return True
81
 
 
 
 
82
  def main():
83
- # Check authentication
84
  if not check_auth():
85
  return
86
-
87
- # Load model
88
  tokenizer, model = load_model()
89
  if tokenizer is None or model is None:
90
  st.error("Failed to load model. Please check the logs.")
91
  return
92
-
93
- # Main interface
94
  st.title("πŸ€– GPT-2 Text Generator")
95
  st.markdown("Generate text using GPT-2 language model")
96
-
97
  # Security status
98
  col1, col2, col3 = st.columns(3)
99
  with col1:
100
- if HF_TOKEN:
101
- st.success("πŸ”‘ HF Token: Active")
102
- else:
103
- st.info("πŸ”‘ HF Token: Not set")
104
-
105
  with col2:
106
- if API_KEY:
107
- st.success("πŸ”’ API Auth: Enabled")
108
- else:
109
- st.info("πŸ”’ API Auth: Disabled")
110
-
111
  with col3:
112
- if ADMIN_PASSWORD:
113
- st.success("πŸ‘€ Admin Auth: Active")
114
- else:
115
- st.info("πŸ‘€ Admin Auth: Disabled")
116
-
117
  # Input section
118
  st.subheader("πŸ“ Input")
119
-
120
  col1, col2 = st.columns([2, 1])
121
-
122
  with col1:
123
  prompt = st.text_area(
124
  "Enter your prompt:",
125
  placeholder="Type your text here...",
126
  height=100
127
  )
128
-
129
- # API key input if needed
130
  api_key = ""
131
  if API_KEY:
132
  api_key = st.text_input("API Key:", type="password")
133
-
134
  with col2:
135
  st.subheader("βš™οΈ Settings")
136
  max_length = st.slider("Max Length", 20, 200, 100, 10)
137
  temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
138
-
139
  generate_btn = st.button("πŸš€ Generate Text", type="primary")
140
-
141
  # API key validation
142
  if API_KEY and generate_btn:
143
  if not api_key or api_key != API_KEY:
144
  st.error("πŸ”’ Invalid or missing API key")
145
  return
146
-
147
  # Generate text
148
  if generate_btn and prompt:
149
  with st.spinner("Generating text..."):
150
  result = generate_text(prompt, max_length, temperature, tokenizer, model)
151
-
152
  st.subheader("πŸ“„ Generated Text")
153
  st.text_area("Output:", value=result, height=200)
154
-
155
- # Copy button
156
  st.code(result)
157
-
158
  elif generate_btn:
159
  st.warning("Please enter a prompt")
160
-
161
- # Examples
162
  st.subheader("πŸ’‘ Example Prompts")
163
  examples = [
164
  "Once upon a time in a distant galaxy,",
@@ -166,17 +157,17 @@ def main():
166
  "In the heart of the ancient forest,",
167
  "The detective walked into the room and noticed"
168
  ]
169
-
170
  cols = st.columns(len(examples))
171
  for i, example in enumerate(examples):
172
  with cols[i]:
173
  if st.button(f"Use Example {i+1}", key=f"ex_{i}"):
174
  st.session_state.example_prompt = example
175
- st.rerun()
176
-
177
- # Use selected example
178
  if hasattr(st.session_state, 'example_prompt'):
179
  st.info(f"Example selected: {st.session_state.example_prompt}")
180
 
 
181
  if __name__ == "__main__":
182
- main()
 
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
 
6
+ # ----------------------------
7
+ # Page config
8
+ # ----------------------------
9
  st.set_page_config(
10
  page_title="GPT-2 Text Generator",
11
  page_icon="πŸ€–",
12
  layout="wide"
13
  )
14
 
15
+ # ----------------------------
16
  # Load environment variables
17
+ # ----------------------------
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
+ API_KEY = os.getenv("API_KEY")
20
  ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
21
 
22
+ # ----------------------------
23
+ # Model loading
24
+ # ----------------------------
25
  @st.cache_resource
26
  def load_model():
27
  """Load and cache the GPT-2 model"""
 
35
  st.error(f"Error loading model: {e}")
36
  return None, None
37
 
38
+ # ----------------------------
39
+ # Text generation
40
+ # ----------------------------
41
  def generate_text(prompt, max_length, temperature, tokenizer, model):
42
  """Generate text using GPT-2"""
43
  if not prompt:
44
  return "Please enter a prompt"
45
+
46
  if len(prompt) > 500:
47
  return "Prompt too long (max 500 characters)"
48
+
49
  try:
 
50
  inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)
51
+
 
52
  with torch.no_grad():
53
  outputs = model.generate(
54
  inputs,
 
59
  eos_token_id=tokenizer.eos_token_id,
60
  no_repeat_ngram_size=2
61
  )
62
+
 
63
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
  new_text = generated_text[len(prompt):].strip()
65
+
66
  return new_text if new_text else "No text generated. Try a different prompt."
67
+
68
  except Exception as e:
69
  return f"Error generating text: {str(e)}"
70
 
71
+ # ----------------------------
72
+ # Authentication
73
+ # ----------------------------
74
  def check_auth():
75
  """Handle authentication"""
76
  if ADMIN_PASSWORD:
77
  if "authenticated" not in st.session_state:
78
  st.session_state.authenticated = False
79
+
80
  if not st.session_state.authenticated:
81
  st.title("πŸ”’ Authentication Required")
82
  password = st.text_input("Enter admin password:", type="password")
83
  if st.button("Login"):
84
  if password == ADMIN_PASSWORD:
85
  st.session_state.authenticated = True
86
+ st.experimental_rerun()
87
  else:
88
  st.error("Invalid password")
89
  return False
90
  return True
91
 
92
+ # ----------------------------
93
+ # Main UI
94
+ # ----------------------------
95
  def main():
 
96
  if not check_auth():
97
  return
98
+
 
99
  tokenizer, model = load_model()
100
  if tokenizer is None or model is None:
101
  st.error("Failed to load model. Please check the logs.")
102
  return
103
+
 
104
  st.title("πŸ€– GPT-2 Text Generator")
105
  st.markdown("Generate text using GPT-2 language model")
106
+
107
  # Security status
108
  col1, col2, col3 = st.columns(3)
109
  with col1:
110
+ st.success("πŸ”‘ HF Token: Active" if HF_TOKEN else "πŸ”‘ HF Token: Not set")
 
 
 
 
111
  with col2:
112
+ st.success("πŸ”’ API Auth: Enabled" if API_KEY else "πŸ”’ API Auth: Disabled")
 
 
 
 
113
  with col3:
114
+ st.success("πŸ‘€ Admin Auth: Active" if ADMIN_PASSWORD else "πŸ‘€ Admin Auth: Disabled")
115
+
 
 
 
116
  # Input section
117
  st.subheader("πŸ“ Input")
 
118
  col1, col2 = st.columns([2, 1])
119
+
120
  with col1:
121
  prompt = st.text_area(
122
  "Enter your prompt:",
123
  placeholder="Type your text here...",
124
  height=100
125
  )
 
 
126
  api_key = ""
127
  if API_KEY:
128
  api_key = st.text_input("API Key:", type="password")
129
+
130
  with col2:
131
  st.subheader("βš™οΈ Settings")
132
  max_length = st.slider("Max Length", 20, 200, 100, 10)
133
  temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
 
134
  generate_btn = st.button("πŸš€ Generate Text", type="primary")
135
+
136
  # API key validation
137
  if API_KEY and generate_btn:
138
  if not api_key or api_key != API_KEY:
139
  st.error("πŸ”’ Invalid or missing API key")
140
  return
141
+
142
  # Generate text
143
  if generate_btn and prompt:
144
  with st.spinner("Generating text..."):
145
  result = generate_text(prompt, max_length, temperature, tokenizer, model)
 
146
  st.subheader("πŸ“„ Generated Text")
147
  st.text_area("Output:", value=result, height=200)
 
 
148
  st.code(result)
 
149
  elif generate_btn:
150
  st.warning("Please enter a prompt")
151
+
152
+ # Example prompts
153
  st.subheader("πŸ’‘ Example Prompts")
154
  examples = [
155
  "Once upon a time in a distant galaxy,",
 
157
  "In the heart of the ancient forest,",
158
  "The detective walked into the room and noticed"
159
  ]
160
+
161
  cols = st.columns(len(examples))
162
  for i, example in enumerate(examples):
163
  with cols[i]:
164
  if st.button(f"Use Example {i+1}", key=f"ex_{i}"):
165
  st.session_state.example_prompt = example
166
+ st.experimental_rerun()
167
+
 
168
  if hasattr(st.session_state, 'example_prompt'):
169
  st.info(f"Example selected: {st.session_state.example_prompt}")
170
 
171
+
172
  if __name__ == "__main__":
173
+ main()