awacke1 commited on
Commit
5ab442a
·
1 Parent(s): 2d0adbe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import onnx
8
+ from io import BytesIO
9
+
10
+ class SimpleNN(nn.Module):
11
+ def __init__(self):
12
+ super(SimpleNN, self).__init__()
13
+ self.fc = nn.Linear(28 * 28, 10) # Assuming 28x28 input and 10 classes
14
+
15
+ def forward(self, x):
16
+ x = x.view(-1, 28 * 28)
17
+ x = self.fc(x)
18
+ return x
19
+
20
+ st.title("PyTorch Neural Network Interface")
21
+
22
+ uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
23
+
24
+ if uploaded_file:
25
+ byte_stream = BytesIO(uploaded_file.getvalue())
26
+ model = onnx.load(byte_stream)
27
+ st.write("Model uploaded successfully!")
28
+
29
+ if st.button('Download Model as ONNX'):
30
+ buffer = BytesIO()
31
+ torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
32
+ st.download_button(
33
+ label="Download ONNX model",
34
+ data=buffer,
35
+ file_name="model.onnx",
36
+ mime="application/octet-stream"
37
+ )
38
+
39
+
40
+