summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff.cpp28
-rw-r--r--tests/autodiff/cuda-kernel-export-2.slang48
2 files changed, 65 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 1909f860c..ddffd0e21 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -693,16 +693,19 @@ struct AutoDiffPass : public InstPassBase
bool lowerIntermediateContextType(IRBuilder* builder)
{
- bool changed = false;
+ bool result = false;
OrderedHashSet<IRInst*> loweredIntermediateTypes;
// Replace all `BackwardDiffIntermediateContextType` insts with the struct type
// that we generated during backward diff pass.
- processAllInsts([&](IRInst* inst)
- {
- switch (inst->getOp())
+ for (;;)
+ {
+ bool changed = false;
+ processAllInsts([&](IRInst* inst)
{
- case kIROp_BackwardDiffIntermediateContextType:
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
{
auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst);
@@ -719,15 +722,18 @@ struct AutoDiffPass : public InstPassBase
}
}
break;
- default:
- break;
- }
- });
-
+ default:
+ break;
+ }
+ });
+ result |= changed;
+ if (!changed)
+ break;
+ }
// Now we generate the differential type for the intermediate context type
// to allow higher order differentiation.
generateDifferentialImplementationForContextType(loweredIntermediateTypes);
- return changed;
+ return result;
}
// Utility function for topology sorting the intermediate context types.
diff --git a/tests/autodiff/cuda-kernel-export-2.slang b/tests/autodiff/cuda-kernel-export-2.slang
new file mode 100644
index 000000000..9cbb4e881
--- /dev/null
+++ b/tests/autodiff/cuda-kernel-export-2.slang
@@ -0,0 +1,48 @@
+//DISABLE_TEST:SIMPLE: -target cuda -line-directive-mode none
+
+// Verify that we can output a cuda device function with [CudaDeviceExport].
+// Disabled until we have FileCheck.
+
+
+//////////////////////////////////////////////////////////////////////////
+// Lambda GGX
+//////////////////////////////////////////////////////////////////////////
+
+[CudaDeviceExport]
+[BackwardDifferentiable]
+float lambdaGGX(const float alphaSqr, const float cosTheta)
+{
+ const float SPECULAR_EPSILON = 1e-4f;
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
+ float cosThetaSqr = _cosTheta * _cosTheta;
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
+ return 0.5f * (sqrt(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
+}
+
+[CudaDeviceExport]
+void lambdaGGX_bwd(inout DifferentialPair<float> alphaSqr, inout DifferentialPair<float> cosTheta, const float d_out)
+{
+ __bwd_diff(lambdaGGX)(alphaSqr, cosTheta, d_out);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Masking Smith
+//////////////////////////////////////////////////////////////////////////
+
+[CudaDeviceExport]
+[BackwardDifferentiable]
+float maskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
+{
+ float lambdaI = lambdaGGX(alphaSqr, cosThetaI);
+ float lambdaO = lambdaGGX(alphaSqr, cosThetaO);
+ return 1.0f / (1.0f + lambdaI + lambdaO);
+}
+
+[CudaDeviceExport]
+void maskingSmithGGXCorrelated_bwd(inout DifferentialPair<float> alphaSqr,
+ inout DifferentialPair<float> cosThetaI,
+ inout DifferentialPair<float> cosThetaO,
+ const float d_out)
+{
+ __bwd_diff(maskingSmithGGXCorrelated)(alphaSqr, cosThetaI, cosThetaO, d_out);
+}