From 0ef7aa85d3a6b2ff1d6b25576b4d9eff188c1a6a Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 21 Feb 2023 12:51:46 -0800 Subject: Fix transposeCall. (#2669) * Modify control-flow test case * Update reverse-control-flow-3.slang * Fix `transposeCall`. * Fix. --------- Co-authored-by: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-transpose.h | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index f998ae13f..70018b476 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1356,6 +1356,8 @@ struct DiffTransposePass // to inform us this case. Here we just need to generate a load of the derivative variable // and use it as the final argument. args.add(builder->emitLoad(arg->getOperand(0))); + argTypes.add(args.getLast()->getDataType()); + argRequiresLoad.add(false); } else if (auto instPair = as(arg)) { @@ -1369,6 +1371,8 @@ struct DiffTransposePass auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); builder->emitStore(tempVar, pairVal); args.add(tempVar); + argTypes.add(builder->getInOutType(pairType)); + argRequiresLoad.add(false); writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar}); } else if (!as(arg->getDataType()) && getDiffPairType(arg->getDataType())) @@ -1384,17 +1388,20 @@ struct DiffTransposePass auto pairType = as(arg->getDataType()); auto var = builder->emitVar(arg->getDataType()); + auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()); + auto diffZero = builder->emitCallInst( + diffType, + diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), + List()); + // Initialize this var to (arg.primal, 0). builder->emitStore( var, builder->emitMakeDifferentialPair( arg->getDataType(), makePairArg->getPrimalValue(), - builder->emitCallInst( - (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()), - diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), - List()))); - + diffZero)); + args.add(var); argTypes.add(builder->getInOutType(pairType)); argRequiresLoad.add(true); -- cgit v1.2.3