awacke1 commited on
Commit
3c555d2
·
1 Parent(s): 4abc51f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -4
app.py CHANGED
@@ -1,31 +1,84 @@
 
 
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
5
- import torchvision.transforms as transforms
6
- from PIL import Image
7
  import onnx
 
 
 
8
  from io import BytesIO
9
 
 
10
  class SimpleNN(nn.Module):
11
  def __init__(self):
12
  super(SimpleNN, self).__init__()
13
- self.fc = nn.Linear(28 * 28, 10) # Assuming 28x28 input and 10 classes
14
 
15
  def forward(self, x):
16
  x = x.view(-1, 28 * 28)
17
  x = self.fc(x)
18
  return x
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  st.title("PyTorch Neural Network Interface")
21
 
 
22
  uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
23
-
24
  if uploaded_file:
25
  byte_stream = BytesIO(uploaded_file.getvalue())
26
  model = onnx.load(byte_stream)
27
  st.write("Model uploaded successfully!")
28
 
 
29
  if st.button('Download Model as ONNX'):
30
  buffer = BytesIO()
31
  torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
@@ -36,5 +89,30 @@ if st.button('Download Model as ONNX'):
36
  mime="application/octet-stream"
37
  )
38
 
 
 
 
 
 
 
 
 
 
39
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
  import streamlit as st
4
  import torch
5
  import torch.nn as nn
6
  import torch.optim as optim
 
 
7
  import onnx
8
+ import onnxruntime
9
+ import pandas as pd
10
+ from sklearn.preprocessing import LabelEncoder
11
  from io import BytesIO
12
 
13
+ # Define a simple neural network
14
  class SimpleNN(nn.Module):
15
  def __init__(self):
16
  super(SimpleNN, self).__init__()
17
+ self.fc = nn.Linear(28 * 28, 10)
18
 
19
  def forward(self, x):
20
  x = x.view(-1, 28 * 28)
21
  x = self.fc(x)
22
  return x
23
 
24
+ # Neural network for the CSV data
25
+ class EmbeddingNN(nn.Module):
26
+ def __init__(self, num_libraries, num_descriptions, embedding_dim=10):
27
+ super(EmbeddingNN, self).__init__()
28
+ self.embedding = nn.Embedding(num_libraries, embedding_dim)
29
+ self.fc = nn.Linear(embedding_dim, num_descriptions)
30
+
31
+ def forward(self, x):
32
+ x = self.embedding(x)
33
+ x = self.fc(x)
34
+ return x
35
+
36
+ def process_csv(csv_data):
37
+ df = pd.read_csv(StringIO(csv_data))
38
+ library_encoder = LabelEncoder()
39
+ description_encoder = LabelEncoder()
40
+ df['library_encoded'] = library_encoder.fit_transform(df['library_name'])
41
+ df['description_encoded'] = description_encoder.fit_transform(df['description'])
42
+ return df, library_encoder, description_encoder
43
+
44
+ def train_and_export(df):
45
+ model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
46
+ criterion = nn.CrossEntropyLoss()
47
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
48
+
49
+ for epoch in range(50):
50
+ inputs = torch.tensor(df['library_encoded'].values, dtype=torch.long)
51
+ labels = torch.tensor(df['description_encoded'].values, dtype=torch.long)
52
+ optimizer.zero_grad()
53
+ outputs = model(inputs)
54
+ loss = criterion(outputs, labels)
55
+ loss.backward()
56
+ optimizer.step()
57
+
58
+ buffer = BytesIO()
59
+ torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
60
+ return buffer
61
+
62
+ def infer_from_onnx(model_buffer, library_name, library_encoder, description_encoder):
63
+ byte_stream = BytesIO(model_buffer.getvalue())
64
+ onnx_model = onnx.load(byte_stream)
65
+ sess = onnxruntime.InferenceSession(byte_stream.getvalue())
66
+ encoded_library = library_encoder.transform([library_name])
67
+ outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library})
68
+ predicted_description = description_encoder.inverse_transform([outputs[0].argmax()])[0]
69
+ return predicted_description
70
+
71
+ # Streamlit UI
72
  st.title("PyTorch Neural Network Interface")
73
 
74
+ # Model Upload
75
  uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
 
76
  if uploaded_file:
77
  byte_stream = BytesIO(uploaded_file.getvalue())
78
  model = onnx.load(byte_stream)
79
  st.write("Model uploaded successfully!")
80
 
81
+ # Model Download
82
  if st.button('Download Model as ONNX'):
83
  buffer = BytesIO()
84
  torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
 
89
  mime="application/octet-stream"
90
  )
91
 
92
+ # Default CSV Example
93
+ DEFAULT_CSV = """
94
+ library_name,description
95
+ torch,PyTorch is an open-source machine learning library
96
+ tensorflow,Open source software library for high performance numerical computations
97
+ pandas,Data analysis and manipulation tool
98
+ numpy,Library for numerical computations in Python
99
+ scikit-learn,Machine learning library in Python
100
+ """
101
 
102
+ csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
103
+ if st.button('Convert CSV to ONNX Neural Net'):
104
+ df, library_encoder, description_encoder = process_csv(csv_data)
105
+ model_buffer = train_and_export(df)
106
+ st.download_button(
107
+ label="Download ONNX model",
108
+ data=model_buffer,
109
+ file_name="model.onnx",
110
+ mime="application/octet-stream"
111
+ )
112
 
113
+ # Inference
114
+ uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
115
+ library_name_to_infer = st.text_input("Enter a library name for inference:")
116
+ if uploaded_model and library_name_to_infer:
117
+ prediction = infer_from_onnx(uploaded_model, library_name_to_infer, library_encoder, description_encoder)
118
+ st.write(f"Predicted description: {prediction}")