From a911ca6e06ce41e403b80fe6054162393491c8ac Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Mar 2023 10:57:28 -0700 Subject: Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695) * Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-transcriber-base.cpp | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index ed122c862..091e7f1ab 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -304,8 +304,16 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy auto primalPairType = as(primalType); return getOrCreateDiffPairType( builder, - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), + differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_DifferentialPairUserCodeType: + { + auto primalPairType = as(primalType); + return builder->getDifferentialPairUserCodeType( + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), + differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(builder, primalPairType)); } case kIROp_FuncType: @@ -634,6 +642,15 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I builder->markInstAsDifferential(makeDiffPair, as(diffType)->getValueType()); return makeDiffPair; } + case kIROp_DifferentialPairUserCodeType: + { + auto makeDiffPair = builder->emitMakeDifferentialPairUserCode( + diffType, + getDifferentialZeroOfType(builder, as(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + builder->markInstAsDifferential(makeDiffPair, as(diffType)->getValueType()); + return makeDiffPair; + } } if (auto arrayType = as(primalType)) -- cgit v1.2.3