summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-expr.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp54
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h7
3 files changed, 60 insertions, 12 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a52a08f15..7e2bc3822 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2232,9 +2232,18 @@ namespace Slang
{
if (as<DifferentialPairType>(derivType))
{
- // Using inout type on all the derivative parameters
+ // An `in` differentiable parameter becomes an `inout` parameter.
derivType = m_astBuilder->getInOutType(derivType);
}
+ else if (auto inoutType = as<InOutType>(derivType))
+ {
+ if (!as<DifferentialPairType>(inoutType->getValueType()))
+ {
+ // An `inout` non differentiable parameter becomes an `in` parameter
+ // (removing `out`).
+ derivType = inoutType->getValueType();
+ }
+ }
type->paramTypes.add(derivType);
}
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 67387e83a..fed53b037 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -269,9 +269,16 @@ namespace Slang
return diffValueType;
}
+ auto maybeConvertInOutTypeToValueType = [](IRType* type)
+ {
+ if (auto inoutType = as<IRInOutType>(type))
+ return inoutType->getValueType();
+ return type;
+ };
+
// If the param is marked as no_diff, return the primal type.
if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
- return primalNoDiffType;
+ return maybeConvertInOutTypeToValueType(primalNoDiffType);
auto diffPairType = tryGetDiffPairType(builder, paramType);
if (diffPairType)
@@ -281,7 +288,7 @@ namespace Slang
return diffPairType;
}
auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
- return primalType;
+ return maybeConvertInOutTypeToValueType(primalType);
}
// Create an empty func to represent the transcribed func of `origFunc`.
@@ -1003,15 +1010,40 @@ namespace Slang
}
else if (!isRelevantDifferentialPair(fwdParam->getDataType()))
{
- // Case 2: non differentiable, non output parameters.
- // If parameter is not an out param and has nothing to do with differentiation,
- // simply move the parameter to the end.
- //
- fwdParam->removeFromParent();
- fwdDiffParameterBlock->addParam(fwdParam);
- result.primalFuncParams.Add(fwdParam);
- result.propagateFuncParams.Add(fwdParam);
- continue;
+ if (inoutType)
+ {
+ // Case 2: non differentiable inout parameter.
+ // They should become an inout parameter in primal func, but an in parameter in
+ // bwd func.
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
+ result.primalFuncParams.Add(fwdParam);
+
+ primalRefReplacement = fwdParam;
+
+ // Create an in param for the prop func.
+ auto propParam = builder->emitParam(inoutType->getValueType());
+ result.propagateFuncParams.Add(propParam);
+
+ // Create a local var for the out param for the primal part of the prop func.
+ auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType());
+ result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar);
+ auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam);
+ result.propagateFuncSpecificPrimalInsts.add(storeInst);
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar;
+ }
+ else
+ {
+ // Case 3: non differentiable, non output parameters.
+ // If parameter is not an out param and has nothing to do with differentiation,
+ // simply move the parameter to the end.
+ //
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
+ result.primalFuncParams.Add(fwdParam);
+ result.propagateFuncParams.Add(fwdParam);
+ continue;
+ }
}
else if(!inoutType)
{
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 92bd0b0a8..eef820804 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1040,6 +1040,13 @@ struct DiffTransposePass
args.add(nullptr);
argRequiresLoad.add(false);
}
+ else if (as<IRInOutType>(paramType))
+ {
+ arg = builder->emitLoad(arg);
+ args.add(arg);
+ argTypes.add(arg->getDataType());
+ argRequiresLoad.add(false);
+ }
else
{
args.add(arg);