diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-05-02 19:46:59 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-02 19:46:59 -0400 |
| commit | 7ef980f79447f8deec2aaf4a501df29f97cf1a39 (patch) | |
| tree | bd4765cb358e25a467f7bbd4f6a07f11ecb1657a | |
| parent | 6b3095758679a7699dc26d1af5521b65ace2cc83 (diff) | |
Fix unzipping logic for inout non-diff parameters and adjust tests (#4090)
* Fix unzipping logic for inout non-diff parameters and adjust tests
+ Removed `-g0` from `struct-this-parameter.slang` test. Works correctly with the new unzipping logic.
+ Removed `-g0` from `was/warped-sampling-1d.slang` test. Works correctly with DX12 & CS_5_1. CS_5_0 appears to run into an FXC compiler bug with detecting infinite loops where there don't appear to be any.
* Update slang-ir-autodiff-unzip.h
* Update warped-sampling-1d.slang
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 28 | ||||
| -rw-r--r-- | tests/autodiff/struct-this-parameter.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/was/warped-sampling-1d.slang | 2 |
3 files changed, 28 insertions, 6 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index c57dc300f..771a3977e 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -372,10 +372,32 @@ struct DiffUnzipPass } else { - // For non differentiable arguments, we can simply pass the argument as is - // if this isn't a `out` parameter, in which case it is removed from propagate call. - if (!as<IROutType>(arg->getDataType())) + if (auto inOutType = as<IRInOutType>(resolvedPrimalFuncType->getParamType(ii))) + { + // For 'inout' parameter we need to create a temp var to hold the value + // before the primal call. This logic is similar to the 'inout' case for differentiable params + // only we don't need to deal with pair types. + // + auto tempPrimalVar = primalBuilder->emitVar(as<IRPtrTypeBase>(arg->getDataType())->getValueType()); + + auto storeUse = findUniqueStoredVal(cast<IRVar>(arg)); + auto storeInst = cast<IRStore>(storeUse->getUser()); + auto storedVal = storeInst->getVal(); + + primalBuilder->emitStore(tempPrimalVar, storedVal); + + diffArgs.add(tempPrimalVar); + } + else + { + // For pure 'in' type. Simply re-use the original argument inst. + // + // For 'out' type parameters, it doesn't really matter what we pass in here, since + // the tranposition logic will discard the argument anyway (we'll pass in the old arg, + // just to keep the number of arguments consistent) + // diffArgs.add(arg); + } } } diff --git a/tests/autodiff/struct-this-parameter.slang b/tests/autodiff/struct-this-parameter.slang index e1526bd4f..9c8ddc724 100644 --- a/tests/autodiff/struct-this-parameter.slang +++ b/tests/autodiff/struct-this-parameter.slang @@ -1,5 +1,5 @@ -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0 -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -g0 +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/autodiff/was/warped-sampling-1d.slang b/tests/autodiff/was/warped-sampling-1d.slang index 3d4a49267..3a2ca8f92 100644 --- a/tests/autodiff/was/warped-sampling-1d.slang +++ b/tests/autodiff/was/warped-sampling-1d.slang @@ -1,4 +1,4 @@ -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0 +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -profile cs_5_1 -dx12 //TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out,name=endpointDifferentialBuffer RWStructuredBuffer<float> endpointDifferentialBuffer; |
