From a911ca6e06ce41e403b80fe6054162393491c8ac Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Mar 2023 10:57:28 -0700 Subject: Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695) * Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 62 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2819a6d83..08c066f5d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2459,6 +2459,26 @@ namespace Slang if (found) { memoryArena.rewindToCursor(cursor); + + // If the found inst is defined in the same parent as current insert location but + // is located after the insert location, we need to move it to the insert location. + auto foundInst = *found; + if (foundInst->getParent() && foundInst->getParent() == getInsertLoc().getParent() && + getInsertLoc().getMode() == IRInsertLoc::Mode::Before) + { + auto insertLoc = getInsertLoc().getInst(); + bool isAfter = false; + for (auto cur = insertLoc->next; cur; cur = cur->next) + { + if (cur == foundInst) + { + isAfter = true; + break; + } + } + if (isAfter) + foundInst->insertBefore(insertLoc); + } return *found; } } @@ -2779,6 +2799,17 @@ namespace Slang operands); } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPairUserCodeType*)getType( + kIROp_DifferentialPairUserCodeType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType( IRInst* func) { @@ -3162,6 +3193,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as(type)); + SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); + + IRInst* args[] = { primal, differential }; + auto inst = createInstWithTrailingArgs( + this, kIROp_MakeDifferentialPairUserCode, type, 2, args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitSpecializeInst( IRType* type, IRInst* genericVal, @@ -3751,6 +3794,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPairGetDifferentialUserCode, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimalUserCode(IRInst* diffPair) + { + auto valueType = cast(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPairGetPrimalUserCode, + 1, + &diffPair); + } IRInst* IRBuilder::emitMakeMatrix( IRType* type, -- cgit v1.2.3