abdullah63 committed
Commit 332b30f · verified · 1 Parent(s): 2572242

Update app.py

Files changed (1): app.py (+47 -24)
app.py CHANGED
@@ -134,39 +134,62 @@ class Transformer(nn.Module):
 
 # Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
 
 # Load tokenizers
-sp_pseudo = spm.SentencePieceProcessor(model_file="pseudo.model")  # For decoding pseudocode (target)
-sp_code = spm.SentencePieceProcessor(model_file="code.model")  # For encoding C++ (source)
-
-# Load the full saved model (architecture + weights)
-model_path = "transformer_cpp_to_pseudo_30.pth"  # Adjust path to your C++ to pseudocode model
-model = torch.load(model_path, map_location=device, weights_only=False)
-model.eval()
-model = model.to(device)
+try:
+    sp_pseudo = spm.SentencePieceProcessor(model_file="pseudo.model")
+    sp_code = spm.SentencePieceProcessor(model_file="code.model")
+    print("Tokenizers loaded successfully.")
+except Exception as e:
+    print(f"Error loading tokenizers: {e}")
+    raise
+
+# Load the full saved model
+model_path = "transformer_cpp_to_pseudo_30.pth"
+try:
+    model = torch.load(model_path, map_location=device, weights_only=False)
+    model.eval()
+    model = model.to(device)
+    print("Model loaded successfully.")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    raise
 
 def generate_pseudocode(cpp_code, max_len):
     """Generate pseudocode from C++ code with streaming output."""
+    print(f"Input C++ code: {cpp_code}")
     model.eval()
-    src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device)  # Tokenize C++ code
-    tgt = torch.tensor([[2]], dtype=torch.long, device=device)  # <bos_id>=2
 
-    generated_tokens = [2]  # Start with <START>
-    response = ""
-    with torch.no_grad():
-        for _ in range(max_len):
-            output = model(src, tgt)
-            next_token = output[:, -1, :].argmax(-1).item()
-            generated_tokens.append(next_token)
-            tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
-            response = sp_pseudo.decode_ids(generated_tokens)  # Decode to pseudocode
-            yield response  # Yield partial output
-            if next_token == 3:  # <END>=3 (adjust if your EOS ID differs)
-                break
-        yield response  # Final output
+    try:
+        src_tokens = sp_code.encode_as_ids(cpp_code)
+        print(f"Source tokens: {src_tokens}")
+        src = torch.tensor([src_tokens], dtype=torch.long, device=device)
+
+        tgt = torch.tensor([[2]], dtype=torch.long, device=device)  # <bos_id>=2
+        generated_tokens = [2]  # Start with <START>
+        response = ""
+
+        with torch.no_grad():
+            for i in range(max_len):
+                output = model(src, tgt)
+                next_token = output[:, -1, :].argmax(-1).item()
+                generated_tokens.append(next_token)
+                tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
+                response = sp_pseudo.decode_ids(generated_tokens)
+                print(f"Step {i}: Next token = {next_token}, Generated so far: {response}")
+                yield response  # Yield partial output
+                if next_token == 3:  # <END>=3
+                    print("EOS token detected, stopping generation.")
+                    break
+            yield response  # Final output
+    except Exception as e:
+        print(f"Error in generation: {e}")
+        yield f"Error: {e}"
 
 def respond(message, history, max_tokens):
     """Wrapper for Gradio interface."""
+    print(f"Received message: {message}")
     for response in generate_pseudocode(message, max_tokens):
         yield response
 
@@ -183,4 +206,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(debug=True)  # Enable debug mode for more output
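
A note on the retained torch.load(..., weights_only=False): the checkpoint here is a fully pickled nn.Module (architecture plus weights), and PyTorch 2.6+ defaults torch.load to weights_only=True, which rejects such files. The explicit flag keeps loading working, but is only safe for trusted checkpoints.

Because generate_pseudocode is a plain Python generator, the new debug output can also be exercised without launching Gradio. A minimal smoke-test sketch, assuming app.py is importable from the same directory and its module-level tokenizer/model loading succeeds; the C++ snippet and max_len value are illustrative, not from this commit:

    # smoke_test.py -- hypothetical helper, not part of this commit
    from app import generate_pseudocode

    cpp_snippet = "int main() { int a = 5; return a; }"  # illustrative input
    final = ""
    for partial in generate_pseudocode(cpp_snippet, max_len=64):
        final = partial  # each yield is the full decode so far, not a delta

    print("Final pseudocode:", final)

Yielding the cumulative string rather than per-token deltas matches what gr.ChatInterface expects from a streaming fn, which is why respond can forward the generator's output unchanged.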