summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h17
1 files changed, 12 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index f998ae13f..70018b476 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1356,6 +1356,8 @@ struct DiffTransposePass
// to inform us this case. Here we just need to generate a load of the derivative variable
// and use it as the final argument.
args.add(builder->emitLoad(arg->getOperand(0)));
+ argTypes.add(args.getLast()->getDataType());
+ argRequiresLoad.add(false);
}
else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg))
{
@@ -1369,6 +1371,8 @@ struct DiffTransposePass
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
builder->emitStore(tempVar, pairVal);
args.add(tempVar);
+ argTypes.add(builder->getInOutType(pairType));
+ argRequiresLoad.add(false);
writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
}
else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
@@ -1384,17 +1388,20 @@ struct DiffTransposePass
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
auto var = builder->emitVar(arg->getDataType());
+ auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType());
+ auto diffZero = builder->emitCallInst(
+ diffType,
+ diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
+ List<IRInst*>());
+
// Initialize this var to (arg.primal, 0).
builder->emitStore(
var,
builder->emitMakeDifferentialPair(
arg->getDataType(),
makePairArg->getPrimalValue(),
- builder->emitCallInst(
- (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()),
- diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
- List<IRInst*>())));
-
+ diffZero));
+
args.add(var);
argTypes.add(builder->getInOutType(pairType));
argRequiresLoad.add(true);