Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,6 @@ import torch.optim as optim
|
|
7 |
import onnx
|
8 |
import onnxruntime
|
9 |
import pandas as pd
|
10 |
-
from sklearn.preprocessing import LabelEncoder
|
11 |
from io import BytesIO
|
12 |
|
13 |
# Define a simple neural network
|
@@ -35,11 +34,12 @@ class EmbeddingNN(nn.Module):
|
|
35 |
|
36 |
def process_csv(csv_data):
|
37 |
df = pd.read_csv(StringIO(csv_data))
|
38 |
-
|
39 |
-
|
40 |
-
df['library_encoded'] =
|
41 |
-
df['description_encoded'] =
|
42 |
-
|
|
|
43 |
|
44 |
def train_and_export(df):
|
45 |
model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
|
@@ -59,13 +59,17 @@ def train_and_export(df):
|
|
59 |
torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
|
60 |
return buffer
|
61 |
|
62 |
-
def infer_from_onnx(model_buffer, library_name,
|
63 |
byte_stream = BytesIO(model_buffer.getvalue())
|
64 |
onnx_model = onnx.load(byte_stream)
|
65 |
sess = onnxruntime.InferenceSession(byte_stream.getvalue())
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
69 |
return predicted_description
|
70 |
|
71 |
# Streamlit UI
|
@@ -96,12 +100,11 @@ torch,PyTorch is an open-source machine learning library
|
|
96 |
tensorflow,Open source software library for high performance numerical computations
|
97 |
pandas,Data analysis and manipulation tool
|
98 |
numpy,Library for numerical computations in Python
|
99 |
-
scikit-learn,Machine learning library in Python
|
100 |
"""
|
101 |
|
102 |
csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
|
103 |
if st.button('Convert CSV to ONNX Neural Net'):
|
104 |
-
df,
|
105 |
model_buffer = train_and_export(df)
|
106 |
st.download_button(
|
107 |
label="Download ONNX model",
|
@@ -114,5 +117,5 @@ if st.button('Convert CSV to ONNX Neural Net'):
|
|
114 |
uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
|
115 |
library_name_to_infer = st.text_input("Enter a library name for inference:")
|
116 |
if uploaded_model and library_name_to_infer:
|
117 |
-
prediction = infer_from_onnx(uploaded_model, library_name_to_infer,
|
118 |
st.write(f"Predicted description: {prediction}")
|
|
|
7 |
import onnx
|
8 |
import onnxruntime
|
9 |
import pandas as pd
|
|
|
10 |
from io import BytesIO
|
11 |
|
12 |
# Define a simple neural network
|
|
|
34 |
|
35 |
def process_csv(csv_data):
|
36 |
df = pd.read_csv(StringIO(csv_data))
|
37 |
+
|
38 |
+
# Replace LabelEncoder with custom encoding using pandas factorize
|
39 |
+
df['library_encoded'], library_classes = df['library_name'].factorize()
|
40 |
+
df['description_encoded'], description_classes = df['description'].factorize()
|
41 |
+
|
42 |
+
return df, library_classes, description_classes
|
43 |
|
44 |
def train_and_export(df):
|
45 |
model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
|
|
|
59 |
torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
|
60 |
return buffer
|
61 |
|
62 |
+
def infer_from_onnx(model_buffer, library_name, library_classes, description_classes):
|
63 |
byte_stream = BytesIO(model_buffer.getvalue())
|
64 |
onnx_model = onnx.load(byte_stream)
|
65 |
sess = onnxruntime.InferenceSession(byte_stream.getvalue())
|
66 |
+
|
67 |
+
# Replace transform with custom encoding
|
68 |
+
encoded_library = torch.tensor([list(library_classes).index(library_name)], dtype=torch.long)
|
69 |
+
|
70 |
+
outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library.numpy()})
|
71 |
+
predicted_description = description_classes[outputs[0].argmax()]
|
72 |
+
|
73 |
return predicted_description
|
74 |
|
75 |
# Streamlit UI
|
|
|
100 |
tensorflow,Open source software library for high performance numerical computations
|
101 |
pandas,Data analysis and manipulation tool
|
102 |
numpy,Library for numerical computations in Python
|
|
|
103 |
"""
|
104 |
|
105 |
csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
|
106 |
if st.button('Convert CSV to ONNX Neural Net'):
|
107 |
+
df, library_classes, description_classes = process_csv(csv_data)
|
108 |
model_buffer = train_and_export(df)
|
109 |
st.download_button(
|
110 |
label="Download ONNX model",
|
|
|
117 |
uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
|
118 |
library_name_to_infer = st.text_input("Enter a library name for inference:")
|
119 |
if uploaded_model and library_name_to_infer:
|
120 |
+
prediction = infer_from_onnx(uploaded_model, library_name_to_infer, library_classes, description_classes)
|
121 |
st.write(f"Predicted description: {prediction}")
|