awacke1 commited on
Commit
d3899c1
·
1 Parent(s): 6822a31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -32,13 +32,11 @@ class EmbeddingNN(nn.Module):
32
  x = self.fc(x)
33
  return x
34
 
 
35
  def process_csv(csv_data):
36
  df = pd.read_csv(StringIO(csv_data))
37
-
38
- # Replace LabelEncoder with custom encoding using pandas factorize
39
  df['library_encoded'], library_classes = df['library_name'].factorize()
40
  df['description_encoded'], description_classes = df['description'].factorize()
41
-
42
  return df, library_classes, description_classes
43
 
44
  def train_and_export(df):
@@ -63,13 +61,9 @@ def infer_from_onnx(model_buffer, library_name, library_classes, description_cla
63
  byte_stream = BytesIO(model_buffer.getvalue())
64
  onnx_model = onnx.load(byte_stream)
65
  sess = onnxruntime.InferenceSession(byte_stream.getvalue())
66
-
67
- # Replace transform with custom encoding
68
  encoded_library = torch.tensor([list(library_classes).index(library_name)], dtype=torch.long)
69
-
70
  outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library.numpy()})
71
  predicted_description = description_classes[outputs[0].argmax()]
72
-
73
  return predicted_description
74
 
75
  # Streamlit UI
@@ -103,8 +97,11 @@ numpy,Library for numerical computations in Python
103
  """
104
 
105
  csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
 
 
 
 
106
  if st.button('Convert CSV to ONNX Neural Net'):
107
- df, library_classes, description_classes = process_csv(csv_data)
108
  model_buffer = train_and_export(df)
109
  st.download_button(
110
  label="Download ONNX model",
 
32
  x = self.fc(x)
33
  return x
34
 
35
+ @st.cache
36
  def process_csv(csv_data):
37
  df = pd.read_csv(StringIO(csv_data))
 
 
38
  df['library_encoded'], library_classes = df['library_name'].factorize()
39
  df['description_encoded'], description_classes = df['description'].factorize()
 
40
  return df, library_classes, description_classes
41
 
42
  def train_and_export(df):
 
61
  byte_stream = BytesIO(model_buffer.getvalue())
62
  onnx_model = onnx.load(byte_stream)
63
  sess = onnxruntime.InferenceSession(byte_stream.getvalue())
 
 
64
  encoded_library = torch.tensor([list(library_classes).index(library_name)], dtype=torch.long)
 
65
  outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library.numpy()})
66
  predicted_description = description_classes[outputs[0].argmax()]
 
67
  return predicted_description
68
 
69
  # Streamlit UI
 
97
  """
98
 
99
  csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
100
+
101
+ # Process CSV and cache the results
102
+ df, library_classes, description_classes = process_csv(csv_data)
103
+
104
  if st.button('Convert CSV to ONNX Neural Net'):
 
105
  model_buffer = train_and_export(df)
106
  st.download_button(
107
  label="Download ONNX model",