diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index f998ae13f..70018b476 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1356,6 +1356,8 @@ struct DiffTransposePass // to inform us this case. Here we just need to generate a load of the derivative variable // and use it as the final argument. args.add(builder->emitLoad(arg->getOperand(0))); + argTypes.add(args.getLast()->getDataType()); + argRequiresLoad.add(false); } else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg)) { @@ -1369,6 +1371,8 @@ struct DiffTransposePass auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); builder->emitStore(tempVar, pairVal); args.add(tempVar); + argTypes.add(builder->getInOutType(pairType)); + argRequiresLoad.add(false); writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar}); } else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType())) @@ -1384,17 +1388,20 @@ struct DiffTransposePass auto pairType = as<IRDifferentialPairType>(arg->getDataType()); auto var = builder->emitVar(arg->getDataType()); + auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()); + auto diffZero = builder->emitCallInst( + diffType, + diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), + List<IRInst*>()); + // Initialize this var to (arg.primal, 0). builder->emitStore( var, builder->emitMakeDifferentialPair( arg->getDataType(), makePairArg->getPrimalValue(), - builder->emitCallInst( - (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()), - diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), - List<IRInst*>()))); - + diffZero)); + args.add(var); argTypes.add(builder->getInOutType(pairType)); argRequiresLoad.add(true); |
