DilipKY committed
Commit c7a8ac7 · verified · 1 Parent(s): a4bcde0

Update app.py

Files changed (1)
  1. app.py +33 -20
app.py CHANGED

@@ -2,7 +2,7 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from datetime import datetime
 
-# Custom CSS for a Grok/ChatGPT-like look
+# Custom CSS for UI
 st.markdown("""
 <style>
 .main { background-color: #f9f9f9; padding: 20px; }
@@ -57,29 +57,44 @@ st.markdown("""
 </style>
 """, unsafe_allow_html=True)
 
-# Load model and tokenizer
-try:
+# Cache model and tokenizer to avoid reloading
+@st.cache_resource
+def load_model_and_tokenizer():
     checkpoint = "Salesforce/codegen-350M-mono"
-    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-    model = AutoModelForCausalLM.from_pretrained(checkpoint)
-except Exception as e:
-    st.error(f"Error loading model: {e}")
+    try:
+        st.write("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+        st.write("Loading model...")
+        model = AutoModelForCausalLM.from_pretrained(checkpoint)
+        st.write("Model and tokenizer loaded successfully!")
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Failed to load model/tokenizer: {e}")
+        return None, None
+
+# Load model and tokenizer once
+tokenizer, model = load_model_and_tokenizer()
+if tokenizer is None or model is None:
     st.stop()
 
 # Function to generate code
 def generate_code(description):
     prompt = f"Generate Python code for the following task: {description}\n"
     inputs = tokenizer(prompt, return_tensors="pt")
-    outputs = model.generate(
-        **inputs,
-        max_length=500,
-        num_return_sequences=1,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    code = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return code[len(prompt):].strip()
+    try:
+        outputs = model.generate(
+            **inputs,
+            max_length=500,
+            num_return_sequences=1,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        code = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return code[len(prompt):].strip()
+    except Exception as e:
+        st.error(f"Error generating code: {e}")
+        return "Error: Could not generate code."
 
-# Initialize chat history in session state
+# Initialize chat history
 if "chat_history" not in st.session_state:
     st.session_state.chat_history = []
 
@@ -95,13 +110,12 @@ with st.container():
         height=150
     )
 
-    col1, col2 = st.columns([1, 1])  # Adjusted for two buttons
+    col1, col2 = st.columns([1, 1])
     with col1:
         if st.button("Generate"):
            if description.strip():
                with st.spinner("Thinking..."):
                    generated_code = generate_code(description)
-                   # Append to chat history with timestamp
                    st.session_state.chat_history.append({
                        "input": description,
                        "output": generated_code,
@@ -121,7 +135,6 @@ with st.container():
     for chat in st.session_state.chat_history:
         st.markdown(f'<div class="chat-message"><strong>You ({chat["time"]}):</strong> {chat["input"]}</div>', unsafe_allow_html=True)
         st.markdown(f'<div class="code-output">{chat["output"]}</div>', unsafe_allow_html=True)
-        st.markdown("---")  # Separator for readability
+        st.markdown("---")
 
-# Optional tip at the bottom
 st.info("Tip: Check the generated code for accuracy before using it!")
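
The gist of the change: model loading now sits behind a function decorated with Streamlit's st.cache_resource, so the CodeGen checkpoint is downloaded and instantiated once per server process rather than on every rerun of the script, and both loading and generation are guarded by try/except so failures surface as st.error messages instead of tracebacks. A minimal sketch of the caching pattern in isolation (the load_model name and its default argument are illustrative, not part of this commit):

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache_resource  # body runs only on the first call; later reruns reuse the returned objects
def load_model(checkpoint: str = "Salesforce/codegen-350M-mono"):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model

tokenizer, model = load_model()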