summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-04 20:07:14 -0800
committerGitHub <noreply@github.com>2023-02-04 20:07:14 -0800
commita12c5511a9003efb23b265a7f2f613cf49aa9f07 (patch)
treeb23bab09ae99df1516a89ac60f9779cf979ff2ef
parent228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (diff)
Patch transcription of `inout` non differentiable params. (#2623)
-rw-r--r--source/slang/slang-check-expr.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp54
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h7
-rw-r--r--tests/autodiff/reverse-inout-param-2.slang3
4 files changed, 61 insertions, 14 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a52a08f15..7e2bc3822 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2232,9 +2232,18 @@ namespace Slang
{
if (as<DifferentialPairType>(derivType))
{
- // Using inout type on all the derivative parameters
+ // An `in` differentiable parameter becomes an `inout` parameter.
derivType = m_astBuilder->getInOutType(derivType);
}
+ else if (auto inoutType = as<InOutType>(derivType))
+ {
+ if (!as<DifferentialPairType>(inoutType->getValueType()))
+ {
+ // An `inout` non differentiable parameter becomes an `in` parameter
+ // (removing `out`).
+ derivType = inoutType->getValueType();
+ }
+ }
type->paramTypes.add(derivType);
}
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 67387e83a..fed53b037 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -269,9 +269,16 @@ namespace Slang
return diffValueType;
}
+ auto maybeConvertInOutTypeToValueType = [](IRType* type)
+ {
+ if (auto inoutType = as<IRInOutType>(type))
+ return inoutType->getValueType();
+ return type;
+ };
+
// If the param is marked as no_diff, return the primal type.
if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
- return primalNoDiffType;
+ return maybeConvertInOutTypeToValueType(primalNoDiffType);
auto diffPairType = tryGetDiffPairType(builder, paramType);
if (diffPairType)
@@ -281,7 +288,7 @@ namespace Slang
return diffPairType;
}
auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
- return primalType;
+ return maybeConvertInOutTypeToValueType(primalType);
}
// Create an empty func to represent the transcribed func of `origFunc`.
@@ -1003,15 +1010,40 @@ namespace Slang
}
else if (!isRelevantDifferentialPair(fwdParam->getDataType()))
{
- // Case 2: non differentiable, non output parameters.
- // If parameter is not an out param and has nothing to do with differentiation,
- // simply move the parameter to the end.
- //
- fwdParam->removeFromParent();
- fwdDiffParameterBlock->addParam(fwdParam);
- result.primalFuncParams.Add(fwdParam);
- result.propagateFuncParams.Add(fwdParam);
- continue;
+ if (inoutType)
+ {
+ // Case 2: non differentiable inout parameter.
+ // They should become an inout parameter in primal func, but an in parameter in
+ // bwd func.
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
+ result.primalFuncParams.Add(fwdParam);
+
+ primalRefReplacement = fwdParam;
+
+ // Create an in param for the prop func.
+ auto propParam = builder->emitParam(inoutType->getValueType());
+ result.propagateFuncParams.Add(propParam);
+
+ // Create a local var for the out param for the primal part of the prop func.
+ auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType());
+ result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar);
+ auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam);
+ result.propagateFuncSpecificPrimalInsts.add(storeInst);
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar;
+ }
+ else
+ {
+ // Case 3: non differentiable, non output parameters.
+ // If parameter is not an out param and has nothing to do with differentiation,
+ // simply move the parameter to the end.
+ //
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
+ result.primalFuncParams.Add(fwdParam);
+ result.propagateFuncParams.Add(fwdParam);
+ continue;
+ }
}
else if(!inoutType)
{
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 92bd0b0a8..eef820804 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1040,6 +1040,13 @@ struct DiffTransposePass
args.add(nullptr);
argRequiresLoad.add(false);
}
+ else if (as<IRInOutType>(paramType))
+ {
+ arg = builder->emitLoad(arg);
+ args.add(arg);
+ argTypes.add(arg->getDataType());
+ argRequiresLoad.add(false);
+ }
else
{
args.add(arg);
diff --git a/tests/autodiff/reverse-inout-param-2.slang b/tests/autodiff/reverse-inout-param-2.slang
index 1fa64751b..18eb825e6 100644
--- a/tests/autodiff/reverse-inout-param-2.slang
+++ b/tests/autodiff/reverse-inout-param-2.slang
@@ -56,14 +56,13 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
p.m = 1.0;
p.n = 2.0;
- ND v2 = { 1.0 };
+ let v2 : ND = { 1.0 };
var x = diffPair(5.0);
float yDiffOut = 1.0;
__bwd_diff(f)(p, v2, x, yDiffOut);
- // (3+((3+x)*x))*((3+x)*x) = (3+3x+x^2)*(3x+x^2)
outputBuffer[0] = x.p; // should be 5, since bwd_diff does not write back new primal val.
outputBuffer[1] = x.d; // 14
outputBuffer[2] = p.m; // 1.0