summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp3
-rw-r--r--source/slang/slang-ir-autodiff.cpp19
-rw-r--r--tests/autodiff/nodiff-ptr.slang40
-rw-r--r--tests/autodiff/nodiff-ptr.slang.expected.txt6
5 files changed, 69 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index f0ac428c7..36093518a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -512,11 +512,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
{
// If primal parameter is mutable, we need to pass in a temp var.
auto tempVar = builder.emitVar(primalParamPtrType->getValueType());
- if (primalParamPtrType->getOp() == kIROp_InOutType)
- {
- // If the primal parameter is inout, we need to set the initial value.
- builder.emitStore(tempVar, primalArg);
- }
+
+ // We also need to setup the initial value of the temp var, otherwise
+ // the temp var will be uninitialized which could cause undefined behavior
+ // in the primal function.
+ builder.emitStore(tempVar, primalArg);
+
primalArgs.add(tempVar);
}
else
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index ada35689c..1b3825a7d 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -565,6 +565,9 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
// If this is a PtrType (out, inout, etc..), then create diff pair from
// value type and re-apply the appropropriate PtrType wrapper.
//
+ if (isNoDiffType(originalType))
+ return nullptr;
+
if (auto origPtrType = as<IRPtrTypeBase>(originalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 5c05b0811..4edd8eabe 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -126,13 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType(
bool isNoDiffType(IRType* paramType)
{
- while (auto ptrType = as<IRPtrTypeBase>(paramType))
- paramType = ptrType->getValueType();
- while (auto attrType = as<IRAttributedType>(paramType))
+ while (paramType)
{
- if (attrType->findAttr<IRNoDiffAttr>())
+ if (auto attrType = as<IRAttributedType>(paramType))
{
- return true;
+ if (attrType->findAttr<IRNoDiffAttr>())
+ return true;
+
+ paramType = attrType->getBaseType();
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(paramType))
+ {
+ paramType = ptrType->getValueType();
+ }
+ else
+ {
+ return false;
}
}
return false;
diff --git a/tests/autodiff/nodiff-ptr.slang b/tests/autodiff/nodiff-ptr.slang
new file mode 100644
index 000000000..d20abddac
--- /dev/null
+++ b/tests/autodiff/nodiff-ptr.slang
@@ -0,0 +1,40 @@
+
+[Differentiable]
+float sumOfSquares(float x, float y, no_diff float4* test)
+{
+ return x * x + y * y * (test->x + test->y + test->z);
+}
+
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly
+
+//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4)
+uniform float* ptr;
+
+//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ float4* testPtr = (float4*)ptr;
+
+ let result = sumOfSquares(2.0, 3.0, testPtr);
+
+ // Use forward differentiation to compute the gradient of the output w.r.t. x only.
+ let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr);
+
+ // Create a differentiable pair to pass in the primal value and to receive the gradient.
+ var dpX = diffPair(2.0);
+ var dpY = diffPair(3.0);
+
+ // Propagate the gradient of the output (1.0f) to the input parameters.
+ bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0);
+
+ outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58
+ outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4
+ outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58
+ outputBuffer[3] = dpX.d; // 2*x = 4
+
+ outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36
+}
diff --git a/tests/autodiff/nodiff-ptr.slang.expected.txt b/tests/autodiff/nodiff-ptr.slang.expected.txt
new file mode 100644
index 000000000..959cc68e4
--- /dev/null
+++ b/tests/autodiff/nodiff-ptr.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+58.000000
+4.000000
+58.000000
+4.000000
+36.000000