diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-15 23:26:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-15 23:26:14 -0700 |
| commit | f7431f96e1cad2a68534bebc1f25cd6f65f87f82 (patch) | |
| tree | a9171c4705f53f246f8a75d5ced752d6538a448c /source | |
| parent | 3a1a6c47cf7f24b09016e08b7cd7e2863911b08d (diff) | |
Fix `transcribeConstruct` for `makeStruct`. (#2703)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 61 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 1 |
2 files changed, 61 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 247c3ddde..e9c156055 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -346,6 +346,64 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* } } +InstPair ForwardDiffTranscriber::transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct) +{ + IRInst* primalMakeStruct = maybeCloneForPrimalInst(builder, origMakeStruct); + + // Check if the output type can be differentiated. If it cannot be + // differentiated, don't differentiate the inst + // + auto primalStructType = (IRType*)findOrTranscribePrimalInst(builder, origMakeStruct->getDataType()); + if (auto diffStructType = differentiateType(builder, primalStructType)) + { + auto primalStruct = as<IRStructType>(getResolvedInstForDecorations(primalStructType)); + SLANG_RELEASE_ASSERT(primalStruct); + + List<IRInst*> diffOperands; + UIndex ii = 0; + for (auto field : primalStruct->getFields()) + { + SLANG_RELEASE_ASSERT(ii < origMakeStruct->getOperandCount()); + + // If this field is not differentiable, skip the operand. + if (!field->getKey()->findDecoration<IRDerivativeMemberDecoration>()) + { + ii++; + continue; + } + + // If the operand has a differential version, replace the original with + // the differential. Otherwise, use a zero. + // + if (auto diffInst = lookupDiffInst(origMakeStruct->getOperand(ii), nullptr)) + { + diffOperands.add(diffInst); + } + else + { + auto operandDataType = origMakeStruct->getOperand(ii)->getDataType(); + auto diffOperandType = differentiateType(builder, operandDataType); + SLANG_RELEASE_ASSERT(diffOperandType); + operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } + ii++; + } + + return InstPair( + primalMakeStruct, + builder->emitIntrinsicInst( + diffStructType, + kIROp_MakeStruct, + diffOperands.getCount(), + diffOperands.getBuffer())); + } + else + { + return InstPair(primalMakeStruct, nullptr); + } +} + static bool _isDifferentiableFunc(IRInst* func) { func = getResolvedInstForDecorations(func); @@ -1551,10 +1609,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_IntCast: case kIROp_FloatCast: case kIROp_MakeVectorFromScalar: - case kIROp_MakeStruct: case kIROp_MakeArray: case kIROp_MakeArrayFromElement: return transcribeConstruct(builder, origInst); + case kIROp_MakeStruct: + return transcribeMakeStruct(builder, origInst); case kIROp_LookupWitness: return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 6032c2319..e9774be49 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -46,6 +46,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // and then return nullptr. Literals do not need to be differentiated. // InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct); + InstPair transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct); // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials |
