From bbd1e1786401bb88c34802b987d4da72e2364503 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 1 Feb 2023 14:18:57 -0800 Subject: Support `out` parameters in backward differentiation. (#2619) * Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2a4ae59a7..4814726cf 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3745,15 +3745,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) { - return emitIntrinsicInst( - diffType, - kIROp_DifferentialPairGetDifferential, - 1, - &diffPair); - } - - IRInst* IRBuilder::emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair) - { + SLANG_ASSERT(as(diffPair->getDataType())); return emitIntrinsicInst( diffType, kIROp_DifferentialPairGetDifferential, @@ -3763,7 +3755,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) { - auto valueType = as(diffPair->getDataType())->getValueType(); + auto valueType = cast(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( valueType, kIROp_DifferentialPairGetPrimal, @@ -3771,16 +3763,6 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairAddressPrimal(IRInst* diffPair) - { - auto valueType = as( - as(diffPair->getDataType())->getValueType())->getValueType(); - return emitIntrinsicInst( - this->getPtrType(kIROp_PtrType, valueType), - kIROp_DifferentialPairGetPrimal, - 1, - &diffPair); - } IRInst* IRBuilder::emitMakeMatrix( IRType* type, @@ -4240,6 +4222,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitLoadReverseGradient(IRType* type, IRInst* diffValue) + { + auto inst = createInst( + this, + kIROp_LoadReverseGradient, + type, + diffValue); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitLoad( IRType* type, IRInst* ptr) @@ -6818,6 +6812,7 @@ namespace Slang case kIROp_MakeTuple: case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads + case kIROp_LoadReverseGradient: case kIROp_ImageSubscript: case kIROp_FieldExtract: case kIROp_FieldAddress: -- cgit v1.2.3