From 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 3 Feb 2023 16:44:33 -0800 Subject: Overhaul `transposeParameterBlock` to support `inout` params. (#2621) * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * Add test coverage on inout param. * Fix language server hinting for transcribed mutable params. --------- Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 4814726cf..558574bf6 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4234,6 +4234,50 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitReverseGradientDiffPairRef(IRType* type, IRInst* primalVar, IRInst* diffVar) + { + auto inst = createInst( + this, + kIROp_ReverseGradientDiffPairRef, + type, + primalVar, + diffVar); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitPrimalParamRef(IRInst* param) + { + auto type = param->getFullType(); + auto ptrType = as(type); + auto valueType = type; + if (ptrType) valueType = ptrType->getValueType(); + auto pairType = as(valueType); + IRType* finalType = pairType->getValueType(); + if (ptrType) finalType = getPtrType(ptrType->getOp(), finalType); + auto inst = createInst( + this, + kIROp_PrimalParamRef, + finalType, + param); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitDiffParamRef(IRType* type, IRInst* param) + { + auto inst = createInst( + this, + kIROp_DiffParamRef, + type, + param); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitLoad( IRType* type, IRInst* ptr) @@ -6753,6 +6797,8 @@ namespace Slang // common subexpression elimination, etc. // auto call = cast(this); + if (call->findDecoration()) + return false; return !isPureFunctionalCall(call); } break; @@ -6809,10 +6855,14 @@ namespace Slang case kIROp_MakeOptionalNone: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + case kIROp_MakeDifferentialPair: case kIROp_MakeTuple: case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_LoadReverseGradient: + case kIROp_ReverseGradientDiffPairRef: case kIROp_ImageSubscript: case kIROp_FieldExtract: case kIROp_FieldAddress: -- cgit v1.2.3