Created
May 17, 2023 07:27
-
-
Save DoeringChristian/e396d30efeb4a0730c9a5b6888617402 to your computer and use it in GitHub Desktop.
Neural Radiance (Mitsuba3 from)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This follows the tutorial from https://github.com/ciy405x at https://github.com/krafton-ai/neural-radiosity-tutorial-mitsuba3/blob/main/neural_radiosity.ipynb
from typing import Union | |
import drjit as dr | |
import mitsuba as mi | |
import mitsuba | |
mi.set_variant("cuda_ad_rgb") | |
import os | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from mitsuba.python.ad.integrators.common import ADIntegrator, mis_weight | |
from tqdm import tqdm | |
# --- Scene -----------------------------------------------------------------
# Start from Mitsuba's built-in Cornell box description and register a
# conductor BSDF under the key "glass".
scene_dict = mi.cornell_box()
scene_dict.update({"glass": {"type": "conductor"}})

# Point the small box's BSDF entry at "glass" and re-insert the box so that it
# comes after the "glass" entry in the dictionary (presumably so the BSDF is
# declared before it is referenced -- confirm).
small_box = scene_dict.pop("small-box")
small_box["bsdf"]["id"] = "glass"
scene_dict["small-box"] = small_box
scene: mi.Scene = mi.load_dict(scene_dict)

# NOTE(review): the scene built above is replaced immediately by the one
# loaded from disk; only the file-based scene is used by the rest of the
# script. Confirm which of the two is meant to be active.
scene = mi.load_file("./data/scenes/cornell-box/scene.xml")
# scene = mi.load_file("./data/scenes/veach-ajar/scene.xml")

# --- Training hyper-parameters ----------------------------------------------
# M // 2 right-hand-side samples are drawn for every left-hand-side sample.
M = 32
# Number of left-hand-side surface samples per optimisation step.
batch_size = 1 << 14
# Number of optimiser steps.
total_steps = 1000
# Adam learning rate.
lr = 0.0005
# NOTE(review): not referenced anywhere in the visible code.
seed = 42
from tinycudann import Encoding as NGPEncoding | |
class NRFieldOrig(nn.Module):
    """Radiance field made of a hash-grid position encoding and a ReLU MLP.

    The MLP input is the concatenation of the bounding-box-normalised hit
    position, the world-space incident direction, the shading normal, the
    diffuse reflectance and the encoded position. The output has three
    channels.
    """

    def __init__(self, scene: mi.Scene, width=256, n_hidden=8) -> None:
        """Build the position encoding and the MLP.

        Args:
            scene (mi.Scene): scene whose bounding box is used to normalise
                positions in ``forward``.
            width (int): number of units of every hidden layer.
            n_hidden (int): number of ``Linear(width, width)`` + ReLU pairs
                placed between the input and the output layer.
        """
        super().__init__()
        self.bbox = scene.bbox()

        self.pos_enc = NGPEncoding(
            3,
            {
                "otype": "Grid",
                "type": "Hash",
                "base_resolution": 16,
                "n_levels": 8,
                "n_features_per_level": 4,
                "log2_hashmap_size": 22,
            },
        )

        # Four raw 3-vectors (x, wi, n, f_d) plus the encoded position.
        in_size = 3 * 4 + self.pos_enc.n_output_dims

        # The hidden stack is created first, then the input and output layers,
        # so the layers are instantiated in the same order as before.
        hidden_layers = []
        for _ in range(n_hidden):
            hidden_layers += [nn.Linear(width, width), nn.ReLU(inplace=True)]

        input_layer = nn.Linear(in_size, width)
        input_activation = nn.ReLU(inplace=True)
        output_layer = nn.Linear(width, 3)

        mlp = nn.Sequential(
            input_layer,
            input_activation,
            *hidden_layers,
            output_layer,
        )
        self.network = mlp.to("cuda")

    def forward(self, si: mi.SurfaceInteraction3f):
        """Evaluate the field at a batch of surface interactions.

        Args:
            si (mi.SurfaceInteraction3f): surface interactions to query.

        Returns:
            torch.Tensor: non-negative float32 tensor with three columns.
        """
        # The Dr.Jit inputs are plain features; no Dr.Jit gradients are tracked.
        with dr.suspend_grad():
            extent = self.bbox.max - self.bbox.min
            x = ((si.p - self.bbox.min) / extent).torch()
            wi = si.to_world(si.wi).torch()
            n = si.sh_frame.n.torch()
            f_d = si.bsdf().eval_diffuse_reflectance(si).torch()

        features = [x, wi, n, f_d, self.pos_enc(x)]
        raw = self.network(torch.cat(features, dim=1))

        # The absolute value keeps the prediction non-negative.
        return torch.abs(raw).to(torch.float32)
class NRFieldSh(nn.Module):
    """Radiance field whose incident direction is fed as spherical harmonics.

    Same layout as ``NRFieldOrig`` except that the world-space incident
    direction is replaced by its spherical-harmonics coefficients (without the
    first one) up to order ``wi_order``.
    """

    # Position-encoding configuration. "otype"/"type" select the hash grid
    # explicitly, as NRFieldOrig does; the remaining keys are grid parameters
    # and only make sense for that encoding.
    _POS_ENC_CONFIG = {
        "otype": "Grid",
        "type": "Hash",
        "base_resolution": 16,
        "n_levels": 8,
        "n_features_per_level": 4,
        "log2_hashmap_size": 22,
    }

    @staticmethod
    def sh_coeffs(n):
        """Return ``(n + 1) ** 2``, the coefficient count used for order ``n``."""
        return (n + 1) ** 2

    def __init__(self, scene: mi.Scene, width=256, n_hidden=8, wi_order=5) -> None:
        """Build the position encoding and the MLP.

        Args:
            scene (mi.Scene): scene whose bounding box is used to normalise
                positions in ``forward``.
            width (int): number of units of every hidden layer.
            n_hidden (int): number of ``Linear(width, width)`` + ReLU pairs
                placed between the input and the output layer.
            wi_order (int): spherical-harmonics order passed to
                ``dr.sh_eval`` for the incident direction.
        """
        super().__init__()
        self.bbox = scene.bbox()
        self.wi_order = wi_order

        # Pass a copy so the shared class-level configuration is never mutated.
        self.pos_enc = NGPEncoding(3, dict(self._POS_ENC_CONFIG))

        # Three raw 3-vectors (x, n, f_d), the encoded position, and the SH
        # coefficients minus the first one, which ``forward`` drops.
        in_size = (
            3 * 3
            + self.pos_enc.n_output_dims
            + NRFieldSh.sh_coeffs(wi_order)
            - 1
        )

        hidden_layers = []
        for _ in range(n_hidden):
            hidden_layers.append(nn.Linear(width, width))
            hidden_layers.append(nn.ReLU(inplace=True))

        self.network = nn.Sequential(
            nn.Linear(in_size, width),
            nn.ReLU(inplace=True),
            *hidden_layers,
            nn.Linear(width, 3),
        ).to("cuda")

    def forward(self, si: mi.SurfaceInteraction3f):
        """Evaluate the field at a batch of surface interactions.

        Args:
            si (mi.SurfaceInteraction3f): surface interactions to query.

        Returns:
            torch.Tensor: non-negative float32 tensor with three columns.
        """
        # The Dr.Jit inputs are plain features; no Dr.Jit gradients are tracked.
        with dr.suspend_grad():
            x = ((si.p - self.bbox.min) / (self.bbox.max - self.bbox.min)).torch()

            # One column per SH coefficient; the first coefficient is skipped,
            # which is why ``in_size`` subtracts one.
            wi = si.to_world(si.wi)
            sh_columns = [sh.torch()[:, None] for sh in dr.sh_eval(wi, self.wi_order)]
            sh_wi = torch.concat(sh_columns[1:], dim=1)

            n = si.sh_frame.n.torch()
            f_d = si.bsdf().eval_diffuse_reflectance(si).torch()

        z_x = self.pos_enc(x)
        inp = torch.concat([x, sh_wi, n, f_d, z_x], dim=1)

        # The absolute value keeps the prediction non-negative.
        out = torch.abs(self.network(inp))
        return out.to(torch.float32)
class NeradIntegrator(mi.SamplingIntegrator):
    """Sampling integrator that trains and renders with a neural radiance field.

    ``train`` fits ``self.model`` by minimising the squared difference between
    ``render_lhs`` and ``render_rhs`` at randomly sampled surface points.
    ``sample`` renders by querying the model at the surface interaction
    returned by ``first_non_specular_or_null_si``.
    """

    def __init__(self, model) -> None:
        """Store the model and create the two samplers used by ``train``.

        Args:
            model (torch.nn.Module): callable mapping a
                ``mi.SurfaceInteraction3f`` to a three-column torch tensor.
        """
        super().__init__(mi.Properties())
        self.model = model
        self.l_sampler = mi.load_dict({"type": "independent", "sample_count": 1})
        self.r_sampler = mi.load_dict({"type": "independent", "sample_count": 1})

    def sample_si(
        self,
        scene: mi.Scene,
        shape_sampler: mi.DiscreteDistribution,
        sample1,
        sample2,
        sample3,
        active=True,
    ) -> mi.SurfaceInteraction3f:
        """Sample a batch of surface interactions on the scene's shapes.

        Args:
            scene (mi.Scene): the underlying scene.
            shape_sampler (mi.DiscreteDistribution): distribution over the
                scene's shapes from which a shape index is drawn.
            sample1: 1D sample that selects the shape.
            sample2: 2D sample that selects the position on the shape.
            sample3: 2D sample that selects the incident direction.
            active (bool, optional): mask of active lanes. Defaults to True.

        Returns:
            mi.SurfaceInteraction3f: interaction at the sampled position whose
            ``wi`` is drawn uniformly from the sphere when the shape's BSDF
            has the ``BackSide`` flag and from the hemisphere otherwise.
        """
        shape_index = shape_sampler.sample(sample1, active)
        shape: mi.Shape = dr.gather(mi.ShapePtr, scene.shapes_dr(), shape_index, active)

        ps = shape.sample_position(0.5, sample2, active)
        si = mi.SurfaceInteraction3f(ps, dr.zeros(mi.Color0f))
        si.shape = shape

        two_sided = mi.has_flag(shape.bsdf().flags(), mi.BSDFFlags.BackSide)
        si.wi = dr.select(
            two_sided,
            mi.warp.square_to_uniform_sphere(sample3),
            mi.warp.square_to_uniform_hemisphere(sample3),
        )
        return si

    def first_non_specular_or_null_si(
        self, scene: mi.Scene, si: mi.SurfaceInteraction3f, sampler: mi.Sampler
    ):
        """Follow BSDF samples from ``si`` until the loop condition fails.

        Args:
            scene (mi.Scene): scene object.
            si (mi.SurfaceInteraction3f): starting surface interaction.
            sampler (mi.Sampler): source of random numbers for BSDF sampling.

        Returns:
            tuple: three values:
                - si (mi.SurfaceInteraction3f): interaction at which the walk
                  stopped.
                - β (mi.Spectrum): product of the BSDF weights accumulated
                  along the walk.
                - null_face (mi.Bool): mask of lanes that ended on a face
                  seen from behind (``si.wi.z < 0``) without a ``BackSide``
                  flag on the sampled lobe.
        """
        # Decisions are based on `bsdf_sample.sampled_type`, not `bsdf.flags()`.
        with dr.suspend_grad():
            bsdf_ctx = mi.BSDFContext()

            depth = mi.UInt32(0)
            β = mi.Spectrum(1)
            null_face = mi.Bool(True)
            active = mi.Bool(True)

            loop = mi.Loop(
                name="first_non_specular_or_null_si",
                state=lambda: (sampler, depth, β, active, null_face, si),
            )
            loop.set_max_iterations(6)

            while loop(active):
                bsdf: mi.BSDF = si.bsdf()
                bsdf_sample, bsdf_weight = bsdf.sample(
                    bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active
                )

                null_face &= ~mi.has_flag(
                    bsdf_sample.sampled_type, mi.BSDFFlags.BackSide
                ) & (si.wi.z < 0)
                active &= si.is_valid() & ~null_face

                # NOTE(review): the comment that used to sit here said the
                # walk continues while `si` lies on a Delta surface and ends
                # on a Smooth one, but the code continues only while the
                # sampled lobe has the Glossy flag. Confirm which is intended.
                active &= mi.has_flag(bsdf_sample.sampled_type, mi.BSDFFlags.Glossy)

                ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
                si[active] = scene.ray_intersect(
                    ray,
                    ray_flags=mi.RayFlags.All,
                    coherent=dr.eq(depth, 0),
                    active=active,
                )

                β[active] *= bsdf_weight
                depth[active] += 1

        return si, β, null_face

    def render_lhs(
        self, scene: mi.Scene, si: mi.SurfaceInteraction3f, mode: str = "drjit"
    ) -> mi.Color3f | torch.Tensor:
        """Evaluate emitted radiance plus the model prediction at ``si``.

        Args:
            scene (mi.Scene): scene object.
            si (mi.SurfaceInteraction3f): surface interactions to evaluate.
            mode (str): ``"drjit"`` returns a Dr.Jit spectrum, ``"torch"``
                returns a torch tensor built from the raw model output.

        Returns:
            mi.Color3f | torch.Tensor: emitter radiance plus the model output,
            the latter zeroed for invalid interactions and for back faces of
            BSDFs without the ``BackSide`` flag.

        Raises:
            ValueError: if ``mode`` is neither ``"drjit"`` nor ``"torch"``.
        """
        if mode not in ("drjit", "torch"):
            raise ValueError(f"Unknown mode {mode!r}; expected 'drjit' or 'torch'.")

        with dr.suspend_grad():
            Le = si.emitter(scene).eval(si)

            # Discard the back side of BSDFs that are not two-sided.
            null_face = ~mi.has_flag(si.bsdf().flags(), mi.BSDFFlags.BackSide) & (
                si.wi.z < 0
            )
            mask = si.is_valid() & ~null_face

            out = self.model(si)

            if mode == "drjit":
                return Le + dr.select(mask, mi.Spectrum(out), 0)
            return Le.torch() + out * mask.torch().reshape(-1, 1)

    def render_rhs(
        self, scene: mi.Scene, si: mi.SurfaceInteraction3f, sampler, mode="drjit"
    ) -> mi.Color3f | torch.Tensor:
        """Estimate the radiance at ``si`` with one bounce plus the model.

        The estimate is the sum of the emission at ``si``, an emitter-sampled
        direct term, the MIS-weighted emission found along a BSDF-sampled
        direction, and the model prediction at the interaction reached along
        that direction.

        Args:
            scene (mi.Scene): scene object.
            si (mi.SurfaceInteraction3f): surface interactions to evaluate.
            sampler (mi.Sampler): source of random numbers.
            mode (str): ``"drjit"`` returns a Dr.Jit spectrum, ``"torch"``
                returns a torch tensor built from the raw model output.

        Returns:
            mi.Color3f | torch.Tensor: the radiance estimate.

        Raises:
            ValueError: if ``mode`` is neither ``"drjit"`` nor ``"torch"``.
        """
        if mode not in ("drjit", "torch"):
            raise ValueError(f"Unknown mode {mode!r}; expected 'drjit' or 'torch'.")

        with dr.suspend_grad():
            bsdf_ctx = mi.BSDFContext()
            β = mi.Spectrum(1)
            bsdf = si.bsdf()

            # Emission at the shading point itself.
            Le = β * si.emitter(scene).eval(si)

            # Emitter sampling, restricted to BSDFs with a Smooth component.
            active_next = si.is_valid()
            active_em = active_next & mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth)
            ds, em_weight = scene.sample_emitter_direction(
                si, sampler.next_2d(), True, active_em
            )
            active_em &= dr.neq(ds.pdf, 0.0)
            wo = si.to_local(ds.d)
            bsdf_value_em, bsdf_pdf_em = bsdf.eval_pdf(bsdf_ctx, si, wo, active_em)
            # Delta emitters get full weight; others are MIS-weighted.
            mis_em = dr.select(ds.delta, 1, mis_weight(ds.pdf, bsdf_pdf_em))
            Lr_dir = β * mis_em * bsdf_value_em * em_weight

            # BSDF sampling.
            bsdf_sample, bsdf_weight = bsdf.sample(
                bsdf_ctx, si, sampler.next_1d(), sampler.next_2d(), active_next
            )

            L = Le + Lr_dir
            ray = si.spawn_ray(si.to_world(bsdf_sample.wo))
            β *= bsdf_weight

            # Remember the current vertex for the MIS weight of the next hit.
            prev_si = dr.detach(si, True)
            prev_bsdf_pdf = bsdf_sample.pdf
            prev_bsdf_delta = mi.has_flag(bsdf_sample.sampled_type, mi.BSDFFlags.Delta)

            si = scene.ray_intersect(ray, ray_flags=mi.RayFlags.All, coherent=True)
            si, β2, null_face = self.first_non_specular_or_null_si(scene, si, sampler)
            β *= β2

            # Emission found along the BSDF-sampled direction.
            ds = mi.DirectionSample3f(scene, si=si, ref=prev_si)
            mis = mis_weight(
                prev_bsdf_pdf,
                scene.pdf_emitter_direction(prev_si, ds, ~prev_bsdf_delta),
            )
            Le_next = si.emitter(scene).eval(si)
            L += β * mis * Le_next

            # The model only contributes where the hit is valid, not a null
            # face, and the emission compares equal to zero.
            out = self.model(si)
            active_nr = si.is_valid() & ~null_face & dr.eq(Le_next, mi.Spectrum(0))
            w_nr = β * mis

            if mode == "drjit":
                return L + dr.select(active_nr, w_nr * mi.Spectrum(out), 0)
            return L.torch() + out * dr.select(active_nr, w_nr, 0).torch()

    def sample(
        self,
        scene: mi.Scene,
        sampler: mi.Sampler,
        ray: mi.Ray3f,
        medium: mi.Medium = None,
        active: bool = True,
    ) -> tuple[mi.Color3f, mi.Bool, list[mi.Color3f]]:
        """Render radiance along ``ray`` by querying the model.

        Args:
            scene (mi.Scene): scene object.
            sampler (mi.Sampler): source of random numbers.
            ray (mi.Ray3f): camera rays.
            medium (mi.Medium, optional): unused.
            active (bool, optional): unused.

        Returns:
            tuple: the radiance, the validity mask of the interaction at which
            the model was queried, and an empty AOV list.
        """
        self.model.eval()
        with torch.no_grad():
            ray = mi.Ray3f(dr.detach(ray))
            si = scene.ray_intersect(ray, ray_flags=mi.RayFlags.All, coherent=True)

            # Move to the interaction at which the model is queried.
            si, β, _ = self.first_non_specular_or_null_si(scene, si, sampler)
            L = self.render_lhs(scene, si, mode="drjit")

        self.model.train()
        torch.cuda.empty_cache()
        return β * L, si.is_valid(), []

    def train(self):
        """Fit ``self.model`` and store the per-step losses.

        Reads the module-level ``scene``, ``lr``, ``total_steps``,
        ``batch_size`` and ``M``. The losses are stored in
        ``self.train_losses`` and the model is left in evaluation mode.

        Raises:
            Warning: if the scene has no non-emitting shape whose BSDF has the
                ``Smooth`` flag, i.e. there is nothing to sample.
        """
        no_smooth_shape = "No smooth shape. No need of neural network training."

        # Sampling weight per shape: its surface area when it is a non-emitter
        # with a Smooth BSDF component, zero otherwise.
        m_area = []
        for shape in scene.shapes():
            if not shape.is_emitter() and mi.has_flag(
                shape.bsdf().flags(), mi.BSDFFlags.Smooth
            ):
                m_area.append(shape.surface_area())
            else:
                m_area.append([0.0])

        if not m_area:
            raise Warning(no_smooth_shape)
        m_area = np.array(m_area)[:, 0]

        # Test the total area, not the number of shapes: a scene whose shapes
        # all have zero weight would otherwise be normalised by zero.
        total_area = m_area.sum()
        if not total_area > 0:
            raise Warning(no_smooth_shape)
        m_area = m_area / total_area

        shape_sampler = mi.DiscreteDistribution(m_area)

        # Optimise the model owned by this integrator rather than whatever the
        # module-level name `field` happens to be bound to.
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        criterion = torch.nn.MSELoss()
        n_rhs = M // 2

        train_losses = []
        tqdm_iterator = tqdm(range(total_steps))

        self.model.train()
        for step in tqdm_iterator:
            optimizer.zero_grad()

            # Fresh, re-seeded clones every step (per the original author,
            # this avoids a lengthy Dr.Jit computation graph).
            r_sampler = self.r_sampler.clone()
            l_sampler = self.l_sampler.clone()
            r_sampler.seed(step, batch_size * M // 2)
            l_sampler.seed(step, batch_size)

            si_lhs = self.sample_si(
                scene,
                shape_sampler,
                l_sampler.next_1d(),
                l_sampler.next_2d(),
                l_sampler.next_2d(),
            )

            # Repeat every LHS interaction M // 2 times for the RHS estimate.
            indices = dr.arange(mi.UInt, 0, batch_size)
            indices = dr.repeat(indices, n_rhs)
            si_rhs = dr.gather(type(si_lhs), si_lhs, indices)

            lhs = self.render_lhs(scene, si_lhs, mode="torch")
            rhs = self.render_rhs(scene, si_rhs, r_sampler, mode="torch")

            # Average the repeated RHS samples of each LHS interaction.
            rhs = rhs.reshape(batch_size, n_rhs, 3).mean(dim=1)

            # Per the original author, normalising by
            # `(lhs + rhs).detach() / 2 + 1e-2` made renders biased (dimmer),
            # so no normalisation is applied.
            norm = 1
            loss = criterion(lhs / norm, rhs / norm)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            tqdm_iterator.set_description("Loss %.04f" % loss_value)
            train_losses.append(loss_value)

        self.model.eval()
        self.train_losses = train_losses
# --- Train and render with the direction-as-vector field --------------------
field = NRFieldOrig(scene, n_hidden=3)
integrator = NeradIntegrator(field)
integrator.train()
image_orig = mi.render(scene, spp=1, integrator=integrator)
losses_orig = integrator.train_losses

# Disabled experiments: the same pipeline with the spherical-harmonics field,
# rendered at 16 spp and kept in `image_sh` / `losses_sh` (default order) and
# `image_sh_2` / `losses_sh_2` (wi_order=2):
#   field = NRFieldSh(scene)            # or NRFieldSh(scene, wi_order=2)
#   integrator = NeradIntegrator(field)
#   integrator.train()
#   image_sh = mi.render(scene, spp=16, integrator=integrator)
#   losses_sh = integrator.train_losses

# Reference image rendered with the scene's own integrator.
ref_image = mi.render(scene, spp=16)

# --- Plot: neural render (top-left), reference (bottom-left), loss curve -----
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
fig.patch.set_visible(False)  # hide the figure's background

for panel, image in ((ax[0][0], image_orig), (ax[1][0], ref_image)):
    panel.axis("off")  # no axes around images
    panel.imshow(mi.util.convert_to_bitmap(image))

ax[1][1].plot(losses_orig, color="red")
# The disabled experiments were drawn with ax[0][1] / ax[0][2] for the images
# and with "green" / "yellow" loss curves on ax[1][1].

fig.tight_layout()  # remove extra white space around the panels
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment