diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-21 12:51:46 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-21 12:51:46 -0800 |
| commit | 0ef7aa85d3a6b2ff1d6b25576b4d9eff188c1a6a (patch) | |
| tree | fd8f2b6e528e01a90cc2f34b2fe8ebf6cc5f97a9 /source | |
| parent | 6bca0ec355aae2955c7de1cd16c2dc0dfe46f19c (diff) | |
Fix transposeCall. (#2669)
* Modify control-flow test case
* Update reverse-control-flow-3.slang
* Fix `transposeCall`.
* Fix.
---------
Co-authored-by: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index f998ae13f..70018b476 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1356,6 +1356,8 @@ struct DiffTransposePass // to inform us this case. Here we just need to generate a load of the derivative variable // and use it as the final argument. args.add(builder->emitLoad(arg->getOperand(0))); + argTypes.add(args.getLast()->getDataType()); + argRequiresLoad.add(false); } else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg)) { @@ -1369,6 +1371,8 @@ struct DiffTransposePass auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); builder->emitStore(tempVar, pairVal); args.add(tempVar); + argTypes.add(builder->getInOutType(pairType)); + argRequiresLoad.add(false); writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar}); } else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType())) @@ -1384,17 +1388,20 @@ struct DiffTransposePass auto pairType = as<IRDifferentialPairType>(arg->getDataType()); auto var = builder->emitVar(arg->getDataType()); + auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()); + auto diffZero = builder->emitCallInst( + diffType, + diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), + List<IRInst*>()); + // Initialize this var to (arg.primal, 0). builder->emitStore( var, builder->emitMakeDifferentialPair( arg->getDataType(), makePairArg->getPrimalValue(), - builder->emitCallInst( - (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()), - diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), - List<IRInst*>()))); - + diffZero)); + args.add(var); argTypes.add(builder->getInOutType(pairType)); argRequiresLoad.add(true); |
