File size: 1,022 Bytes
4abc51f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import onnx
from io import BytesIO

class SimpleNN(nn.Module):
    """Single-layer linear classifier over flattened inputs.

    Each sample is flattened to ``in_features`` values and mapped to
    ``num_classes`` raw scores (logits) by one ``nn.Linear`` layer. No
    activation or softmax is applied.

    Args:
        in_features: Number of values per sample after flattening. The
            default, ``28 * 28``, corresponds to a 28x28 image.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, in_features: int = 28 * 28, num_classes: int = 10):
        super().__init__()
        # Attribute name ``fc`` is kept so state_dict keys stay
        # "fc.weight" / "fc.bias".
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(N, num_classes)`` for input ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``in_features`` (e.g. ``(N, 28, 28)`` with the defaults); it is
        flattened to ``(N, in_features)`` first.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or sliced inputs, copying only when it must.
        x = x.reshape(-1, self.fc.in_features)
        return self.fc(x)

st.title("PyTorch Neural Network Interface")

uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")

if uploaded_file is not None:
    # Parse and validate the upload. Failures are reported in the page instead
    # of crashing the app. The checker also rejects bytes that happen to
    # decode as a protobuf message but do not form a valid ONNX model.
    try:
        model = onnx.load(BytesIO(uploaded_file.getvalue()))
        onnx.checker.check_model(model)
    except Exception as err:  # UI boundary: surface any load/validation error
        st.error(f"Could not load ONNX model: {err}")
    else:
        st.write("Model uploaded successfully!")

if st.button('Download Model as ONNX'):
    # NOTE: this exports a freshly constructed SimpleNN, i.e. randomly
    # initialised, untrained weights; the uploaded model is not used here.
    buffer = BytesIO()
    torch.onnx.export(SimpleNN().eval(), torch.randn(1, 28, 28), buffer)
    st.download_button(
        label="Download ONNX model",
        data=buffer.getvalue(),  # hand over the serialized bytes explicitly
        file_name="model.onnx",
        mime="application/octet-stream",
    )