# app.py
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import onnx
import onnxruntime
import pandas as pd
from io import BytesIO, StringIO
# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc(x)
        return x
# Neural network for the CSV data
class EmbeddingNN(nn.Module):
    def __init__(self, num_libraries, num_descriptions, embedding_dim=10):
        super(EmbeddingNN, self).__init__()
        self.embedding = nn.Embedding(num_libraries, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_descriptions)

    def forward(self, x):
        x = self.embedding(x)
        x = self.fc(x)
        return x
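# Shape sketch for EmbeddingNN (illustrative, assuming the default embedding_dim=10):
#   input x        : LongTensor of library indices, shape (batch,)
#   embedding(x)   : (batch, 10)
#   fc(...)        : (batch, num_descriptions) -- raw logits fed to CrossEntropyLoss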
def process_csv(csv_data):
    df = pd.read_csv(StringIO(csv_data))
    # Encode the categorical columns with pandas factorize (no sklearn LabelEncoder needed)
    df['library_encoded'], library_classes = df['library_name'].factorize()
    df['description_encoded'], description_classes = df['description'].factorize()
    return df, library_classes, description_classes
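# For the default CSV below, factorize would yield (illustrative):
#   df['library_encoded'] -> [0, 1, 2, 3]
#   library_classes       -> Index(['torch', 'tensorflow', 'pandas', 'numpy'])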
def train_and_export(df):
    model = EmbeddingNN(len(df['library_encoded'].unique()), len(df['description_encoded'].unique()))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Full-batch training over the whole CSV for 50 epochs
    for epoch in range(50):
        inputs = torch.tensor(df['library_encoded'].values, dtype=torch.long)
        labels = torch.tensor(df['description_encoded'].values, dtype=torch.long)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Export the trained model to an in-memory ONNX buffer
    buffer = BytesIO()
    torch.onnx.export(model, torch.tensor([0], dtype=torch.long), buffer)
    buffer.seek(0)  # rewind so the buffer can be read back for download/inference
    return buffer
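# Note: torch.onnx.export accepts any file-like object, so the model is serialized
# straight into the BytesIO buffer; the dummy input (a single library index) defines
# the traced graph's input shape.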
def infer_from_onnx(model_buffer, library_name, library_classes, description_classes):
    byte_stream = BytesIO(model_buffer.getvalue())
    onnx_model = onnx.load(byte_stream)
    sess = onnxruntime.InferenceSession(byte_stream.getvalue())
    # Encode the library name with the same factorize mapping used during training
    encoded_library = torch.tensor([list(library_classes).index(library_name)], dtype=torch.long)
    outputs = sess.run(None, {sess.get_inputs()[0].name: encoded_library.numpy()})
    predicted_description = description_classes[outputs[0].argmax()]
    return predicted_description
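# onnx.load above only serves as a structural check of the uploaded model; the actual
# inference runs through onnxruntime, and the argmax over the output logits indexes
# back into description_classes.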
# Streamlit UI
st.title("PyTorch Neural Network Interface")

# Model Upload
uploaded_file = st.file_uploader("Choose an ONNX model file", type="onnx")
if uploaded_file:
    byte_stream = BytesIO(uploaded_file.getvalue())
    model = onnx.load(byte_stream)
    st.write("Model uploaded successfully!")
# Model Download
if st.button('Download Model as ONNX'):
    buffer = BytesIO()
    torch.onnx.export(SimpleNN(), torch.randn(1, 28, 28), buffer)
    st.download_button(
        label="Download ONNX model",
        data=buffer.getvalue(),  # pass the bytes so the download is not empty
        file_name="model.onnx",
        mime="application/octet-stream"
    )
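# Note: torch.randn(1, 28, 28) is only a dummy input that fixes the traced input shape;
# SimpleNN flattens it to 28 * 28 = 784 features, so this download is an untrained
# example model, not the one trained from the CSV below.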
# Default CSV Example
DEFAULT_CSV = """
library_name,description
torch,PyTorch is an open-source machine learning library
tensorflow,Open source software library for high performance numerical computations
pandas,Data analysis and manipulation tool
numpy,Library for numerical computations in Python
"""

csv_data = st.text_area("Paste your CSV data here:", value=DEFAULT_CSV)
if st.button('Convert CSV to ONNX Neural Net'):
    df, library_classes, description_classes = process_csv(csv_data)
    model_buffer = train_and_export(df)
    st.download_button(
        label="Download ONNX model",
        data=model_buffer.getvalue(),  # pass the bytes rather than the file object
        file_name="model.onnx",
        mime="application/octet-stream"
    )
# Inference
uploaded_model = st.file_uploader("Choose an ONNX model file for inference", type="onnx")
library_name_to_infer = st.text_input("Enter a library name for inference:")
if uploaded_model and library_name_to_infer:
    # Re-derive the label encodings from the CSV text area so inference does not
    # depend on the Convert button having been pressed earlier in this session
    _, library_classes, description_classes = process_csv(csv_data)
    prediction = infer_from_onnx(uploaded_model, library_name_to_infer, library_classes, description_classes)
    st.write(f"Predicted description: {prediction}")