summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h28
-rw-r--r--tests/autodiff/struct-this-parameter.slang4
-rw-r--r--tests/autodiff/was/warped-sampling-1d.slang2
3 files changed, 28 insertions, 6 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index c57dc300f..771a3977e 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -372,10 +372,32 @@ struct DiffUnzipPass
}
else
{
- // For non differentiable arguments, we can simply pass the argument as is
- // if this isn't a `out` parameter, in which case it is removed from propagate call.
- if (!as<IROutType>(arg->getDataType()))
+ if (auto inOutType = as<IRInOutType>(resolvedPrimalFuncType->getParamType(ii)))
+ {
+ // For 'inout' parameter we need to create a temp var to hold the value
+ // before the primal call. This logic is similar to the 'inout' case for differentiable params
+ // only we don't need to deal with pair types.
+ //
+ auto tempPrimalVar = primalBuilder->emitVar(as<IRPtrTypeBase>(arg->getDataType())->getValueType());
+
+ auto storeUse = findUniqueStoredVal(cast<IRVar>(arg));
+ auto storeInst = cast<IRStore>(storeUse->getUser());
+ auto storedVal = storeInst->getVal();
+
+ primalBuilder->emitStore(tempPrimalVar, storedVal);
+
+ diffArgs.add(tempPrimalVar);
+ }
+ else
+ {
+ // For pure 'in' type. Simply re-use the original argument inst.
+ //
+ // For 'out' type parameters, it doesn't really matter what we pass in here, since
+ // the tranposition logic will discard the argument anyway (we'll pass in the old arg,
+ // just to keep the number of arguments consistent)
+ //
diffArgs.add(arg);
+ }
}
}
diff --git a/tests/autodiff/struct-this-parameter.slang b/tests/autodiff/struct-this-parameter.slang
index e1526bd4f..9c8ddc724 100644
--- a/tests/autodiff/struct-this-parameter.slang
+++ b/tests/autodiff/struct-this-parameter.slang
@@ -1,5 +1,5 @@
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -g0
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/was/warped-sampling-1d.slang b/tests/autodiff/was/warped-sampling-1d.slang
index 3d4a49267..3a2ca8f92 100644
--- a/tests/autodiff/was/warped-sampling-1d.slang
+++ b/tests/autodiff/was/warped-sampling-1d.slang
@@ -1,4 +1,4 @@
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -profile cs_5_1 -dx12
//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out,name=endpointDifferentialBuffer
RWStructuredBuffer<float> endpointDifferentialBuffer;