diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-24 04:33:51 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-24 04:33:51 -0700 |
| commit | 03c10833beb331e234554808c2a80d3cadecc7c0 (patch) | |
| tree | b135201bfb1128409739405ca508a01922a97333 | |
| parent | 56a84a06488afb817f79fbd99e8b470bd587ccd1 (diff) | |
Fix nested bwdContextType lowering. (#2731)
Co-authored-by: Yong He <yhe@nvidia.com>
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 28 |
| -rw-r--r-- | tests/autodiff/cuda-kernel-export-2.slang | 48 |
2 files changed, 65 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 1909f860c..ddffd0e21 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -693,16 +693,19 @@ struct AutoDiffPass : public InstPassBase bool lowerIntermediateContextType(IRBuilder* builder) { - bool changed = false; + bool result = false; OrderedHashSet<IRInst*> loweredIntermediateTypes; // Replace all `BackwardDiffIntermediateContextType` insts with the struct type // that we generated during backward diff pass. - processAllInsts([&](IRInst* inst) - { - switch (inst->getOp()) + for (;;) + { + bool changed = false; + processAllInsts([&](IRInst* inst) { - case kIROp_BackwardDiffIntermediateContextType: + switch (inst->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: { auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst); @@ -719,15 +722,18 @@ struct AutoDiffPass : public InstPassBase } } break; - default: - break; - } - }); - + default: + break; + } + }); + result |= changed; + if (!changed) + break; + } // Now we generate the differential type for the intermediate context type // to allow higher order differentiation. generateDifferentialImplementationForContextType(loweredIntermediateTypes); - return changed; + return result; } // Utility function for topology sorting the intermediate context types. diff --git a/tests/autodiff/cuda-kernel-export-2.slang b/tests/autodiff/cuda-kernel-export-2.slang new file mode 100644 index 000000000..9cbb4e881 --- /dev/null +++ b/tests/autodiff/cuda-kernel-export-2.slang @@ -0,0 +1,48 @@ +//DISABLE_TEST:SIMPLE: -target cuda -line-directive-mode none + +// Verify that we can output a cuda device function with [CudaDeviceExport]. +// Disabled until we have FileCheck. 
+ + +////////////////////////////////////////////////////////////////////////// +// Lambda GGX +////////////////////////////////////////////////////////////////////////// + +[CudaDeviceExport] +[BackwardDifferentiable] +float lambdaGGX(const float alphaSqr, const float cosTheta) +{ + const float SPECULAR_EPSILON = 1e-4f; + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + return 0.5f * (sqrt(1.0f + alphaSqr * tanThetaSqr) - 1.0f); +} + +[CudaDeviceExport] +void lambdaGGX_bwd(inout DifferentialPair<float> alphaSqr, inout DifferentialPair<float> cosTheta, const float d_out) +{ + __bwd_diff(lambdaGGX)(alphaSqr, cosTheta, d_out); +} + +////////////////////////////////////////////////////////////////////////// +// Masking Smith +////////////////////////////////////////////////////////////////////////// + +[CudaDeviceExport] +[BackwardDifferentiable] +float maskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = lambdaGGX(alphaSqr, cosThetaI); + float lambdaO = lambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +[CudaDeviceExport] +void maskingSmithGGXCorrelated_bwd(inout DifferentialPair<float> alphaSqr, + inout DifferentialPair<float> cosThetaI, + inout DifferentialPair<float> cosThetaO, + const float d_out) +{ + __bwd_diff(maskingSmithGGXCorrelated)(alphaSqr, cosThetaI, cosThetaO, d_out); +} |
