Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,31 +1,84 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.optim as optim
|
5 |
-
import torchvision.transforms as transforms
|
6 |
-
from PIL import Image
|
7 |
import onnx
|
|
|
|
|
|
|
8 |
from io import BytesIO
|
9 |
|
|
|
10 |
class SimpleNN(nn.Module):
|
11 |
def __init__(self):
|
12 |
super(SimpleNN, self).__init__()
|
13 |
-
self.fc = nn.Linear(28 * 28, 10)
|
14 |
|
15 |
def forward(self, x):
|
16 |
x = x.view(-1, 28 * 28)
|
17 |
x = self.fc(x)
|
18 |
return x
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
st.title("PyTorch Neural Network Interface")
|
21 |
|
|
|
22 |
uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
|
23 |
-
|
24 |
if uploaded_file:
|
25 |
byte_stream = BytesIO(uploaded_file.getvalue())
|
26 |
model = onnx.load(byte_stream)
|
27 |
st.write("Model uploaded successfully!")
|
28 |
|
|
|
29 |
if st.button('Download Model as ONNX'):
|
30 |
buffer = BytesIO()
|
31 |
torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
|
@@ -36,5 +89,30 @@ if st.button('Download Model as ONNX'):
|
|
36 |
mime="application/octet-stream"
|
37 |
)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
|
3 |
import streamlit as st
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.optim as optim
|
|
|
|
|
7 |
import onnx
|
8 |
+
import onnxruntime
|
9 |
+
import pandas as pd
|
10 |
+
from sklearn.preprocessing import LabelEncoder
|
11 |
from io import BytesIO
|
12 |
|
13 |
+
# Define a simple neural network
|
14 |
class SimpleNN(nn.Module):
|
15 |
def __init__(self):
|
16 |
super(SimpleNN, self).__init__()
|
17 |
+
self.fc = nn.Linear(28 * 28, 10)
|
18 |
|
19 |
def forward(self, x):
|
20 |
x = x.view(-1, 28 * 28)
|
21 |
x = self.fc(x)
|
22 |
return x
|
23 |
|
24 |
+
# Neural network for the CSV data
|
25 |
+
class EmbeddingNN(nn.Module):
|
26 |
+
def __init__(self, num_libraries, num_descriptions, embedding_dim=10):
|
27 |
+
super(EmbeddingNN, self).__init__()
|
28 |
+
self.embedding = nn.Embedding(num_libraries, embedding_dim)
|
29 |
+
self.fc = nn.Linear(embedding_dim, num_descriptions)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = self.embedding(x)
|
33 |
+
x = self.fc(x)
|
34 |
+
return x
|
35 |
+
|
36 |
+
def process_csv(csv_data):
|
37 |
+
df = pd.read_csv(StringIO(csv_data))
|
38 |
+
library_encoder = LabelEncoder()
|
39 |
+
description_encoder = LabelEncoder()
|
40 |
+
df['library_encoded'] = library_encoder.fit_transform(df['library_name'])
|
41 |
+
df['description_encoded'] = description_encoder.fit_transform(df['description'])
|
42 |
+
return df, library_encoder, description_encoder
|
43 |
+
|
44 |
+
def train_and_export(df):
|
45 |
+
model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
|
46 |
+
criterion = nn.CrossEntropyLoss()
|
47 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
48 |
+
|
49 |
+
for epoch in range(50):
|
50 |
+
inputs = torch.tensor(df['library_encoded'].values, dtype=torch.long)
|
51 |
+
labels = torch.tensor(df['description_encoded'].values, dtype=torch.long)
|
52 |
+
optimizer.zero_grad()
|
53 |
+
outputs = model(inputs)
|
54 |
+
loss = criterion(outputs, labels)
|
55 |
+
loss.backward()
|
56 |
+
optimizer.step()
|
57 |
+
|
58 |
+
buffer = BytesIO()
|
59 |
+
torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
|
60 |
+
return buffer
|
61 |
+
|
62 |
+
def infer_from_onnx(model_buffer, library_name, library_encoder, description_encoder):
|
63 |
+
byte_stream = BytesIO(model_buffer.getvalue())
|
64 |
+
onnx_model = onnx.load(byte_stream)
|
65 |
+
sess = onnxruntime.InferenceSession(byte_stream.getvalue())
|
66 |
+
encoded_library = library_encoder.transform([library_name])
|
67 |
+
outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library})
|
68 |
+
predicted_description = description_encoder.inverse_transform([outputs[0].argmax()])[0]
|
69 |
+
return predicted_description
|
70 |
+
|
71 |
+
# Streamlit UI
|
72 |
st.title("PyTorch Neural Network Interface")
|
73 |
|
74 |
+
# Model Upload
|
75 |
uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
|
|
|
76 |
if uploaded_file:
|
77 |
byte_stream = BytesIO(uploaded_file.getvalue())
|
78 |
model = onnx.load(byte_stream)
|
79 |
st.write("Model uploaded successfully!")
|
80 |
|
81 |
+
# Model Download
|
82 |
if st.button('Download Model as ONNX'):
|
83 |
buffer = BytesIO()
|
84 |
torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
|
|
|
89 |
mime="application/octet-stream"
|
90 |
)
|
91 |
|
92 |
+
# Default CSV Example
|
93 |
+
DEFAULT_CSV = """
|
94 |
+
library_name,description
|
95 |
+
torch,PyTorch is an open-source machine learning library
|
96 |
+
tensorflow,Open source software library for high performance numerical computations
|
97 |
+
pandas,Data analysis and manipulation tool
|
98 |
+
numpy,Library for numerical computations in Python
|
99 |
+
scikit-learn,Machine learning library in Python
|
100 |
+
"""
|
101 |
|
102 |
+
csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
|
103 |
+
if st.button('Convert CSV to ONNX Neural Net'):
|
104 |
+
df, library_encoder, description_encoder = process_csv(csv_data)
|
105 |
+
model_buffer = train_and_export(df)
|
106 |
+
st.download_button(
|
107 |
+
label="Download ONNX model",
|
108 |
+
data=model_buffer,
|
109 |
+
file_name="model.onnx",
|
110 |
+
mime="application/octet-stream"
|
111 |
+
)
|
112 |
|
113 |
+
# Inference
|
114 |
+
uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
|
115 |
+
library_name_to_infer = st.text_input("Enter a library name for inference:")
|
116 |
+
if uploaded_model and library_name_to_infer:
|
117 |
+
prediction = infer_from_onnx(uploaded_model, library_name_to_infer, library_encoder, description_encoder)
|
118 |
+
st.write(f"Predicted description: {prediction}")
|