From f7431f96e1cad2a68534bebc1f25cd6f65f87f82 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 15 Mar 2023 23:26:14 -0700 Subject: Fix `transcribeConstruct` for `makeStruct`. (#2703) Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-fwd.cpp | 61 +++++++++++++++++++++++++++++++++- source/slang/slang-ir-autodiff-fwd.h | 1 + 2 files changed, 61 insertions(+), 1 deletion(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 247c3ddde..e9c156055 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -346,6 +346,64 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* } } +InstPair ForwardDiffTranscriber::transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct) +{ + IRInst* primalMakeStruct = maybeCloneForPrimalInst(builder, origMakeStruct); + + // Check if the output type can be differentiated. If it cannot be + // differentiated, don't differentiate the inst + // + auto primalStructType = (IRType*)findOrTranscribePrimalInst(builder, origMakeStruct->getDataType()); + if (auto diffStructType = differentiateType(builder, primalStructType)) + { + auto primalStruct = as(getResolvedInstForDecorations(primalStructType)); + SLANG_RELEASE_ASSERT(primalStruct); + + List diffOperands; + UIndex ii = 0; + for (auto field : primalStruct->getFields()) + { + SLANG_RELEASE_ASSERT(ii < origMakeStruct->getOperandCount()); + + // If this field is not differentiable, skip the operand. + if (!field->getKey()->findDecoration()) + { + ii++; + continue; + } + + // If the operand has a differential version, replace the original with + // the differential. Otherwise, use a zero. + // + if (auto diffInst = lookupDiffInst(origMakeStruct->getOperand(ii), nullptr)) + { + diffOperands.add(diffInst); + } + else + { + auto operandDataType = origMakeStruct->getOperand(ii)->getDataType(); + auto diffOperandType = differentiateType(builder, operandDataType); + SLANG_RELEASE_ASSERT(diffOperandType); + operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } + ii++; + } + + return InstPair( + primalMakeStruct, + builder->emitIntrinsicInst( + diffStructType, + kIROp_MakeStruct, + diffOperands.getCount(), + diffOperands.getBuffer())); + } + else + { + return InstPair(primalMakeStruct, nullptr); + } +} + static bool _isDifferentiableFunc(IRInst* func) { func = getResolvedInstForDecorations(func); @@ -1551,10 +1609,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_IntCast: case kIROp_FloatCast: case kIROp_MakeVectorFromScalar: - case kIROp_MakeStruct: case kIROp_MakeArray: case kIROp_MakeArrayFromElement: return transcribeConstruct(builder, origInst); + case kIROp_MakeStruct: + return transcribeMakeStruct(builder, origInst); case kIROp_LookupWitness: return transcribeLookupInterfaceMethod(builder, as(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 6032c2319..e9774be49 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -46,6 +46,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // and then return nullptr. Literals do not need to be differentiated. // InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct); + InstPair transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct); // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials -- cgit v1.2.3