awacke1's picture
Create v1.app.py
4abc51f
raw
history blame
1.02 kB
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import onnx
from io import BytesIO
class SimpleNN(nn.Module):
    """Single fully connected layer mapping flattened inputs to class scores.

    Each sample is flattened to ``in_features`` values and passed through one
    ``nn.Linear`` layer producing ``num_classes`` raw scores. No activation
    or softmax is applied, so the outputs are unnormalized logits.

    Args:
        in_features: Number of values per sample after flattening. Defaults
            to ``28 * 28`` (e.g. a 28x28 single-channel image).
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, in_features: int = 28 * 28, num_classes: int = 10):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(N, num_classes)`` for input ``x``.

        ``x`` may have any shape whose total element count is a multiple of
        ``in_features`` (e.g. ``(N, 28, 28)`` or ``(N, 1, 28, 28)`` with the
        defaults). It is flattened to ``(N, in_features)`` before the layer.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or sliced inputs, copying only when it has to.
        x = x.reshape(-1, self.fc.in_features)
        return self.fc(x)
st.title("PyTorch Neural Network Interface")

# Let the user upload a serialized ONNX model (the picker offers .onnx files).
uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
if uploaded_file:
    byte_stream = BytesIO(uploaded_file.getvalue())
    try:
        model = onnx.load(byte_stream)
        # onnx.load only parses the protobuf; check_model also validates the
        # graph so the success message below is not shown for a broken model.
        onnx.checker.check_model(model)
    except Exception as exc:  # UI boundary: report the problem, don't crash the page
        st.error(f"Could not load the ONNX model: {exc}")
    else:
        st.write("Model uploaded successfully!")

if st.button('Download Model as ONNX'):
    buffer = BytesIO()
    # NOTE: this exports a freshly constructed (untrained) SimpleNN, not the
    # uploaded model. The random (1, 28, 28) tensor is the example input
    # torch.onnx.export uses to produce the graph.
    torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
    st.download_button(
        label="Download ONNX model",
        data=buffer.getvalue(),  # hand Streamlit the serialized bytes explicitly
        file_name="model.onnx",
        mime="application/octet-stream",
    )