# Spaces: Sleeping
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import onnx | |
from io import BytesIO | |
class SimpleNN(nn.Module):
    """Single-layer linear classifier over flattened inputs.

    Each sample is flattened to ``in_features`` values and mapped to
    ``num_classes`` raw scores (no softmax is applied) by one fully
    connected layer.

    Args:
        in_features: Number of values per sample after flattening
            (default ``28 * 28``).
        num_classes: Number of output scores per sample (default ``10``).
    """

    def __init__(self, in_features: int = 28 * 28, num_classes: int = 10):
        super().__init__()
        # Attribute name ``fc`` is kept so existing state_dicts still load.
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to ``(-1, in_features)`` and return the raw scores.

        ``reshape`` is used instead of ``view`` because ``view`` raises a
        RuntimeError on non-contiguous tensors (e.g. transposed or sliced
        inputs); ``reshape`` copies only when it has to.
        """
        x = x.reshape(-1, self.fc.in_features)
        return self.fc(x)
def _export_simple_nn_onnx() -> bytes:
    """Export a freshly constructed (untrained) SimpleNN to ONNX and return the raw bytes."""
    net = SimpleNN().eval()  # export in inference mode
    dummy_input = torch.randn(1, 28, 28)  # example input used for tracing
    buffer = BytesIO()
    torch.onnx.export(net, dummy_input, buffer)
    return buffer.getvalue()


st.title("PyTorch Neural Network Interface")

uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
if uploaded_file:
    byte_stream = BytesIO(uploaded_file.getvalue())
    # The upload is arbitrary user data: report a parse failure in the UI
    # instead of crashing the whole app with a traceback.
    try:
        model = onnx.load(byte_stream)
    except Exception as err:
        st.error(f"Could not read the uploaded file as an ONNX model: {err}")
    else:
        st.write("Model uploaded successfully!")

if st.button('Download Model as ONNX'):
    st.download_button(
        label="Download ONNX model",
        data=_export_simple_nn_onnx(),
        file_name="model.onnx",
        mime="application/octet-stream",
    )