# PorousMediaGAN / app.py
# Last commit 9bc9fb6 by lmoss: "added better formatting"
# (File-viewer header residue: raw | history | blame, 3.62 kB)
import os
import time

import numpy as np
import pyvista as pv
import requests
import streamlit as st
import streamlit.components.v1 as components
import torch

from dcgan import DCGAN3D_G
# Page title and the static introduction shown above the visualisations.
st.title("Generating Porous Media with GANs")
# Markdown joins consecutive lines into one paragraph, so every sentence needs
# its own terminating punctuation. The text contains no raw HTML, so
# unsafe_allow_html is not needed.
st.markdown(
    """
### Author
_Lukas Mosser (2022)_ - :bird:[porestar](https://twitter.com/porestar)
## Description
This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
published in Physical Review E ([Mosser et al. 2017](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.96.043309)).
The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
## The Demo
Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/).
The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
"""
)
# Pretrained Berea sandstone generator weights from the PorousMediaGAN repository.
url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/berea_generator_epoch_24.pth?raw=true"
checkpoint_path = "berea_generator_epoch_24.pth"

# Streamlit re-executes this whole script on every interaction, so only fetch
# the checkpoint when it is not already on disk.
if not os.path.exists(checkpoint_path):
    # If the repo is private, an auth token must be added to the request headers.
    resp = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors instead of saving an error page as the
    # checkpoint (which would only surface later as an obscure torch.load error).
    resp.raise_for_status()
    # Write to a temporary name and rename afterwards, so an interrupted write
    # never leaves a truncated file that the existence check above would accept.
    tmp_path = checkpoint_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(resp.content)
    os.replace(tmp_path, checkpoint_path)
pv.set_plot_theme("document")

# Number of latent channels. It is passed to the generator and also used as the
# channel dimension of the noise tensor, and the two must agree.
LATENT_CHANNELS = 512
# NOTE(review): the remaining positional arguments (64, 1, 32, 1) presumably are
# image size, output channels, base filter count and GPU count; confirm against
# dcgan.DCGAN3D_G's signature.
netG = DCGAN3D_G(64, LATENT_CHANNELS, 1, 32, 1)
netG.load_state_dict(
    torch.load("berea_generator_epoch_24.pth", map_location=torch.device("cpu"))
)
# NOTE(review): netG.eval() is never called, so the module stays in PyTorch's
# default training mode. Any BatchNorm layers therefore use batch statistics.
# This preserves the original demo's behaviour; confirm before switching to eval().

# One random latent sample with shape (batch=1, channels, 1, 1, 1).
z = torch.randn(1, LATENT_CHANNELS, 1, 1, 1)
with torch.no_grad():
    X = netG(z)

# First sample, first channel. (v + 1) / 2 maps an assumed [-1, 1] (tanh-style)
# output onto [0, 1], and 1 - ... inverts it. Generator value -1 becomes 1,
# presumably pore space, since the mesh contoured at 1 is labelled "Pore Space".
img = 1 - (X[0, 0].numpy() + 1) / 2
# Wrap the generated volume in a uniform grid with unit voxel spacing whose
# origin is at (0, 0, 0), one grid point per voxel of img.
grid = pv.UniformGrid(
    dims=img.shape,
    spacing=(1, 1, 1),
    origin=(0, 0, 0),
)

# VTK orders point data with x varying fastest (Fortran order). A default
# C-order flatten would silently swap the x and z axes of the volume.
values = img.flatten(order="F")
grid.point_data['my_array'] = values

# Orthogonal x/y/z slices, shown in the first viewer.
slices = grid.slice_orthogonal()

# Single isosurface extracted with marching cubes, shown in the second viewer.
# NOTE(review): with one isosurface, VTK's GenerateValues uses rng[0] (= 1) as the
# isovalue, so the surface wraps voxels at the very top of img's [0, 1] range.
# Confirm this threshold is intended (e.g. vs. 0.5). Newer pyvista releases may
# also reject the unsorted rng=[1, 0]; verify against the pinned version.
mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")

# Distance of each mesh vertex from the grid origin, used to colour the mesh.
dist = np.linalg.norm(mesh.points, axis=1)
def _export_plot_html(dataset, filename, **mesh_kwargs):
    """Render *dataset* in a single 400x400 pyvista plotter and save it as HTML.

    Extra keyword arguments are forwarded to ``Plotter.add_mesh``. The plotter
    is always closed afterwards so its render window is released, even if
    exporting fails.
    """
    plotter = pv.Plotter(shape=(1, 1), window_size=(400, 400))
    try:
        plotter.add_mesh(dataset, **mesh_kwargs)
        plotter.export_html(filename)
    finally:
        plotter.close()


# Grayscale orthogonal slices, and the pore mesh coloured by distance from the origin.
_export_plot_html(slices, 'slices.html', cmap="gray")
_export_plot_html(mesh, 'mesh.html', scalars=dist)
# Embedded viewer size, matching the 400x400 plotter windows used for export.
view_width = 400
view_height = 400

# Each exported HTML scene gets its own header, an embedded viewer and a usage hint.
for section_title, html_path in (
    ("3D Intersections", "slices.html"),
    ("3D Pore Space Mesh", "mesh.html"),
):
    # Context manager closes the file; the original left both handles open.
    with open(html_path, 'r', encoding='utf-8') as html_file:
        source_code = html_file.read()
    st.header(section_title)
    components.html(source_code, width=view_width, height=view_height)
    st.markdown("_Click and drag to spin, right click to shift._")
# Citation section. The closing code fence must sit on its own line; when fused
# onto "}" it never closes the block, and the backticks render literally.
st.markdown("""
## Citation
If you use our code for your own research, we would be grateful if you cite our publication:
```
@article{pmgan2017,
title={Reconstruction of three-dimensional porous media using generative adversarial neural networks},
author={Mosser, Lukas and Dubrule, Olivier and Blunt, Martin J.},
journal={arXiv preprint arXiv:1704.03225},
year={2017}
}
```
""")