From 03c10833beb331e234554808c2a80d3cadecc7c0 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 24 Mar 2023 04:33:51 -0700 Subject: Fix nested bwdContextType lowering. (#2731) Co-authored-by: Yong He --- tests/autodiff/cuda-kernel-export-2.slang | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/autodiff/cuda-kernel-export-2.slang (limited to 'tests') diff --git a/tests/autodiff/cuda-kernel-export-2.slang b/tests/autodiff/cuda-kernel-export-2.slang new file mode 100644 index 000000000..9cbb4e881 --- /dev/null +++ b/tests/autodiff/cuda-kernel-export-2.slang @@ -0,0 +1,48 @@ +//DISABLE_TEST:SIMPLE: -target cuda -line-directive-mode none + +// Verify that we can output a cuda device function with [CudaDeviceExport]. +// Disabled until we have FileCheck. + + +////////////////////////////////////////////////////////////////////////// +// Lambda GGX +////////////////////////////////////////////////////////////////////////// + +[CudaDeviceExport] +[BackwardDifferentiable] +float lambdaGGX(const float alphaSqr, const float cosTheta) +{ + const float SPECULAR_EPSILON = 1e-4f; + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + return 0.5f * (sqrt(1.0f + alphaSqr * tanThetaSqr) - 1.0f); +} + +[CudaDeviceExport] +void lambdaGGX_bwd(inout DifferentialPair alphaSqr, inout DifferentialPair cosTheta, const float d_out) +{ + __bwd_diff(lambdaGGX)(alphaSqr, cosTheta, d_out); +} + +////////////////////////////////////////////////////////////////////////// +// Masking Smith +////////////////////////////////////////////////////////////////////////// + +[CudaDeviceExport] +[BackwardDifferentiable] +float maskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = lambdaGGX(alphaSqr, cosThetaI); + float lambdaO = lambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +[CudaDeviceExport] +void maskingSmithGGXCorrelated_bwd(inout DifferentialPair alphaSqr, + inout DifferentialPair cosThetaI, + inout DifferentialPair cosThetaO, + const float d_out) +{ + __bwd_diff(maskingSmithGGXCorrelated)(alphaSqr, cosThetaI, cosThetaO, d_out); +} -- cgit v1.2.3