diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-04 20:07:14 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-04 20:07:14 -0800 |
| commit | a12c5511a9003efb23b265a7f2f613cf49aa9f07 (patch) | |
| tree | b23bab09ae99df1516a89ac60f9779cf979ff2ef | |
| parent | 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (diff) | |
Patch transcription of `inout` non differentiable params. (#2623)
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 54 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 7 | ||||
| -rw-r--r-- | tests/autodiff/reverse-inout-param-2.slang | 3 |
4 files changed, 61 insertions, 14 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a52a08f15..7e2bc3822 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2232,9 +2232,18 @@ namespace Slang { if (as<DifferentialPairType>(derivType)) { - // Using inout type on all the derivative parameters + // An `in` differentiable parameter becomes an `inout` parameter. derivType = m_astBuilder->getInOutType(derivType); } + else if (auto inoutType = as<InOutType>(derivType)) + { + if (!as<DifferentialPairType>(inoutType->getValueType())) + { + // An `inout` non differentiable parameter becomes an `in` parameter + // (removing `out`). + derivType = inoutType->getValueType(); + } + } type->paramTypes.add(derivType); } } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 67387e83a..fed53b037 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -269,9 +269,16 @@ namespace Slang return diffValueType; } + auto maybeConvertInOutTypeToValueType = [](IRType* type) + { + if (auto inoutType = as<IRInOutType>(type)) + return inoutType->getValueType(); + return type; + }; + // If the param is marked as no_diff, return the primal type. if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) - return primalNoDiffType; + return maybeConvertInOutTypeToValueType(primalNoDiffType); auto diffPairType = tryGetDiffPairType(builder, paramType); if (diffPairType) @@ -281,7 +288,7 @@ namespace Slang return diffPairType; } auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType); - return primalType; + return maybeConvertInOutTypeToValueType(primalType); } // Create an empty func to represent the transcribed func of `origFunc`. @@ -1003,15 +1010,40 @@ namespace Slang } else if (!isRelevantDifferentialPair(fwdParam->getDataType())) { - // Case 2: non differentiable, non output parameters. - // If parameter is not an out param and has nothing to do with differentiation, - // simply move the parameter to the end. - // - fwdParam->removeFromParent(); - fwdDiffParameterBlock->addParam(fwdParam); - result.primalFuncParams.Add(fwdParam); - result.propagateFuncParams.Add(fwdParam); - continue; + if (inoutType) + { + // Case 2: non differentiable inout parameter. + // They should become an inout parameter in primal func, but an in parameter in + // bwd func. + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); + result.primalFuncParams.Add(fwdParam); + + primalRefReplacement = fwdParam; + + // Create an in param for the prop func. + auto propParam = builder->emitParam(inoutType->getValueType()); + result.propagateFuncParams.Add(propParam); + + // Create a local var for the out param for the primal part of the prop func. + auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType()); + result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar); + auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam); + result.propagateFuncSpecificPrimalInsts.add(storeInst); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar; + } + else + { + // Case 3: non differentiable, non output parameters. + // If parameter is not an out param and has nothing to do with differentiation, + // simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); + result.primalFuncParams.Add(fwdParam); + result.propagateFuncParams.Add(fwdParam); + continue; + } } else if(!inoutType) { diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 92bd0b0a8..eef820804 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1040,6 +1040,13 @@ struct DiffTransposePass args.add(nullptr); argRequiresLoad.add(false); } + else if (as<IRInOutType>(paramType)) + { + arg = builder->emitLoad(arg); + args.add(arg); + argTypes.add(arg->getDataType()); + argRequiresLoad.add(false); + } else { args.add(arg); diff --git a/tests/autodiff/reverse-inout-param-2.slang b/tests/autodiff/reverse-inout-param-2.slang index 1fa64751b..18eb825e6 100644 --- a/tests/autodiff/reverse-inout-param-2.slang +++ b/tests/autodiff/reverse-inout-param-2.slang @@ -56,14 +56,13 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) p.m = 1.0; p.n = 2.0; - ND v2 = { 1.0 }; + let v2 : ND = { 1.0 }; var x = diffPair(5.0); float yDiffOut = 1.0; __bwd_diff(f)(p, v2, x, yDiffOut); - // (3+((3+x)*x))*((3+x)*x) = (3+3x+x^2)*(3x+x^2) outputBuffer[0] = x.p; // should be 5, since bwd_diff does not write back new primal val. outputBuffer[1] = x.d; // 14 outputBuffer[2] = p.m; // 1.0 |
