"""Streamlit front-end for uploading and downloading ONNX models.

Uploading parses and validates a user-supplied ``.onnx`` file. The download
button exports a freshly constructed ``SimpleNN`` to ONNX bytes.
"""

import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import onnx
from io import BytesIO

IMAGE_SIZE = 28  # side length of the square input image the default model expects
NUM_CLASSES = 10  # number of output logits of the default model


class SimpleNN(nn.Module):
    """Single fully-connected layer over a flattened input.

    Args:
        in_features: Number of features after flattening each sample
            (defaults to 28 * 28).
        num_classes: Number of output logits (defaults to 10).
    """

    def __init__(
        self,
        in_features: int = IMAGE_SIZE * IMAGE_SIZE,
        num_classes: int = NUM_CLASSES,
    ) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to ``(-1, in_features)`` and return the linear logits."""
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. permuted inputs.
        x = x.reshape(-1, self.fc.in_features)
        return self.fc(x)


def _load_uploaded_model(uploaded_file) -> "onnx.ModelProto | None":
    """Parse and validate an uploaded ONNX file, reporting failures in the UI.

    Returns the loaded model, or ``None`` if the file could not be parsed or
    failed ONNX validation (an error message is shown in that case).
    """
    try:
        model = onnx.load(BytesIO(uploaded_file.getvalue()))
        onnx.checker.check_model(model)
    except Exception as exc:  # UI boundary: show the problem instead of a traceback
        st.error(f"Could not load ONNX model: {exc}")
        return None
    return model


def _export_onnx_bytes(model: nn.Module) -> bytes:
    """Export ``model`` to ONNX and return the serialized file contents."""
    model.eval()  # export inference-mode behaviour
    dummy_input = torch.randn(1, IMAGE_SIZE, IMAGE_SIZE)
    buffer = BytesIO()
    torch.onnx.export(
        model,
        dummy_input,
        buffer,
        input_names=["input"],
        output_names=["logits"],
        # Let consumers run the exported graph with any batch size, not just 1.
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    )
    return buffer.getvalue()


st.title("PyTorch Neural Network Interface")

uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
if uploaded_file is not None:
    model = _load_uploaded_model(uploaded_file)
    if model is not None:
        st.write("Model uploaded successfully!")

if st.button('Download Model as ONNX'):
    # NOTE(review): this exports a newly initialised (untrained) SimpleNN, not the
    # uploaded model, which is currently never used -- confirm this is intended.
    st.download_button(
        label="Download ONNX model",
        data=_export_onnx_bytes(SimpleNN()),
        file_name="model.onnx",
        mime="application/octet-stream",
    )