awacke1 committed on
Commit
39a28b1
·
1 Parent(s): 8ceff94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -7,7 +7,6 @@ import torch.optim as optim
7
  import onnx
8
  import onnxruntime
9
  import pandas as pd
10
- from sklearn.preprocessing import LabelEncoder
11
  from io import BytesIO
12
 
13
  # Define a simple neural network
@@ -35,11 +34,12 @@ class EmbeddingNN(nn.Module):
35
 
36
  def process_csv(csv_data):
37
  df = pd.read_csv(StringIO(csv_data))
38
- library_encoder = LabelEncoder()
39
- description_encoder = LabelEncoder()
40
- df['library_encoded'] = library_encoder.fit_transform(df['library_name'])
41
- df['description_encoded'] = description_encoder.fit_transform(df['description'])
42
- return df, library_encoder, description_encoder
 
43
 
44
  def train_and_export(df):
45
  model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
@@ -59,13 +59,17 @@ def train_and_export(df):
59
  torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
60
  return buffer
61
 
62
- def infer_from_onnx(model_buffer, library_name, library_encoder, description_encoder):
63
  byte_stream = BytesIO(model_buffer.getvalue())
64
  onnx_model = onnx.load(byte_stream)
65
  sess = onnxruntime.InferenceSession(byte_stream.getvalue())
66
- encoded_library = library_encoder.transform([library_name])
67
- outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library})
68
- predicted_description = description_encoder.inverse_transform([outputs[0].argmax()])[0]
 
 
 
 
69
  return predicted_description
70
 
71
  # Streamlit UI
@@ -96,12 +100,11 @@ torch,PyTorch is an open-source machine learning library
96
  tensorflow,Open source software library for high performance numerical computations
97
  pandas,Data analysis and manipulation tool
98
  numpy,Library for numerical computations in Python
99
- scikit-learn,Machine learning library in Python
100
  """
101
 
102
  csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
103
  if st.button('Convert CSV to ONNX Neural Net'):
104
- df, library_encoder, description_encoder = process_csv(csv_data)
105
  model_buffer = train_and_export(df)
106
  st.download_button(
107
  label="Download ONNX model",
@@ -114,5 +117,5 @@ if st.button('Convert CSV to ONNX Neural Net'):
114
  uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
115
  library_name_to_infer = st.text_input("Enter a library name for inference:")
116
  if uploaded_model and library_name_to_infer:
117
- prediction = infer_from_onnx(uploaded_model, library_name_to_infer, library_encoder, description_encoder)
118
  st.write(f"Predicted description: {prediction}")
 
7
  import onnx
8
  import onnxruntime
9
  import pandas as pd
 
10
  from io import BytesIO
11
 
12
  # Define a simple neural network
 
34
 
35
  def process_csv(csv_data):
36
  df = pd.read_csv(StringIO(csv_data))
37
+
38
+ # Replace LabelEncoder with custom encoding using pandas factorize
39
+ df['library_encoded'], library_classes = df['library_name'].factorize()
40
+ df['description_encoded'], description_classes = df['description'].factorize()
41
+
42
+ return df, library_classes, description_classes
43
 
44
  def train_and_export(df):
45
  model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
 
59
  torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
60
  return buffer
61
 
62
def infer_from_onnx(model_buffer, library_name, library_classes, description_classes):
    """Predict a description for ``library_name`` with an exported ONNX model.

    Args:
        model_buffer: Object exposing ``getvalue()`` that returns the
            serialized ONNX model bytes (e.g. a ``BytesIO``).
        library_name: Library name to look up; it must be one of
            ``library_classes``.
        library_classes: Sequence of known library names, where each name's
            position is its integer code.
        description_classes: Sequence of descriptions, indexed by the class
            the model predicts.

    Returns:
        The entry of ``description_classes`` at the arg-max of the model's
        first output.

    Raises:
        ValueError: If ``library_name`` is not in ``library_classes``.
    """
    # Resolve the name before touching the model so an unknown library fails
    # fast with a readable message instead of list.index's generic one.
    try:
        library_index = list(library_classes).index(library_name)
    except ValueError:
        raise ValueError(f"Unknown library name: {library_name!r}") from None

    model_bytes = model_buffer.getvalue()

    # Parse once with onnx as an early check that the bytes are a loadable
    # model; the parsed object itself is not needed afterwards.
    onnx.load(BytesIO(model_bytes))
    sess = onnxruntime.InferenceSession(model_bytes)

    encoded_library = torch.tensor([library_index], dtype=torch.long)
    outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library.numpy()})
    predicted_description = description_classes[outputs[0].argmax()]

    return predicted_description
74
 
75
  # Streamlit UI
 
100
  tensorflow,Open source software library for high performance numerical computations
101
  pandas,Data analysis and manipulation tool
102
  numpy,Library for numerical computations in Python
 
103
  """
104
 
105
  csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
106
  if st.button('Convert CSV to ONNX Neural Net'):
107
+ df, library_classes, description_classes = process_csv(csv_data)
108
  model_buffer = train_and_export(df)
109
  st.download_button(
110
  label="Download ONNX model",
 
117
  uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
118
  library_name_to_infer = st.text_input("Enter a library name for inference:")
119
  if uploaded_model and library_name_to_infer:
120
+ prediction = infer_from_onnx(uploaded_model, library_name_to_infer, library_classes, description_classes)
121
  st.write(f"Predicted description: {prediction}")