summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-21 12:51:46 -0800
committerGitHub <noreply@github.com>2023-02-21 12:51:46 -0800
commit0ef7aa85d3a6b2ff1d6b25576b4d9eff188c1a6a (patch)
treefd8f2b6e528e01a90cc2f34b2fe8ebf6cc5f97a9 /source
parent6bca0ec355aae2955c7de1cd16c2dc0dfe46f19c (diff)
Fix transposeCall. (#2669)
* Modify control-flow test case * Update reverse-control-flow-3.slang * Fix `transposeCall`. * Fix. --------- Co-authored-by: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h17
1 files changed, 12 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index f998ae13f..70018b476 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1356,6 +1356,8 @@ struct DiffTransposePass
// to inform us this case. Here we just need to generate a load of the derivative variable
// and use it as the final argument.
args.add(builder->emitLoad(arg->getOperand(0)));
+ argTypes.add(args.getLast()->getDataType());
+ argRequiresLoad.add(false);
}
else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg))
{
@@ -1369,6 +1371,8 @@ struct DiffTransposePass
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
builder->emitStore(tempVar, pairVal);
args.add(tempVar);
+ argTypes.add(builder->getInOutType(pairType));
+ argRequiresLoad.add(false);
writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
}
else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
@@ -1384,17 +1388,20 @@ struct DiffTransposePass
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
auto var = builder->emitVar(arg->getDataType());
+ auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType());
+ auto diffZero = builder->emitCallInst(
+ diffType,
+ diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
+ List<IRInst*>());
+
// Initialize this var to (arg.primal, 0).
builder->emitStore(
var,
builder->emitMakeDifferentialPair(
arg->getDataType(),
makePairArg->getPrimalValue(),
- builder->emitCallInst(
- (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()),
- diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
- List<IRInst*>())));
-
+ diffZero));
+
args.add(var);
argTypes.add(builder->getInOutType(pairType));
argRequiresLoad.add(true);