From ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 19 Sep 2024 03:10:28 -0400 Subject: Support `IDifferentiablePtrType` (#5031) * initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 141 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 137 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6c7691d13..b89929f55 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3022,6 +3022,17 @@ namespace Slang operands); } + IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPtrPairType*)getType( + kIROp_DifferentialPtrPairType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( IRType* valueType, IRInst* witnessTable) @@ -3503,7 +3514,7 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) + IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as(type)); SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); @@ -3516,6 +3527,98 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as(type)); + SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); + + IRInst* args[] = {primal, differential}; + auto inst = createInstWithTrailingArgs( + this, kIROp_MakeDifferentialPtrPair, type, 2, args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal) + { + if (as(pairType)) + { + return emitMakeDifferentialValuePair(pairType, primalVal, diffVal); + } + else if (as(pairType)) + { + // Quick optimization: + // If primalVal and diffVal are extracted from the same pointer-pair, + // we can just use the pointer-pair directly. + // + if (auto primalPtrVal = as(primalVal)) + { + if (auto diffPtrVal = as(diffVal)) + { + if (primalPtrVal->getBase() == diffPtrVal->getBase()) + return primalPtrVal->getBase(); + } + } + return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetDifferential(diffType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetDifferential(diffType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(primalType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(primalType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as(type)); @@ -4222,7 +4325,7 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args); } - IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as(diffPair->getDataType())); return emitIntrinsicInst( @@ -4232,7 +4335,18 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + + IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPtrPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair) { auto valueType = cast(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( @@ -4242,7 +4356,7 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair) { return emitIntrinsicInst( primalType, @@ -4251,6 +4365,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair) + { + auto valueType = cast(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as(diffPair->getDataType())); -- cgit v1.2.3