diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-ir.cpp | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
| -rw-r--r-- | source/slang/slang-ir.cpp | 141 |
1 files changed, 137 insertions, 4 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6c7691d13..b89929f55 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3022,6 +3022,17 @@ namespace Slang operands); } + IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPtrPairType*)getType( + kIROp_DifferentialPtrPairType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( IRType* valueType, IRInst* witnessTable) @@ -3503,7 +3514,7 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) + IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)); SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr); @@ -3516,6 +3527,98 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)); + SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)->getValueType() != nullptr); + + IRInst* args[] = {primal, differential}; + auto inst = createInstWithTrailingArgs<IRMakeDifferentialPtrPair>( + this, kIROp_MakeDifferentialPtrPair, type, 2, args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal) + { + if (as<IRDifferentialPairType>(pairType)) + { + return emitMakeDifferentialValuePair(pairType, primalVal, diffVal); + } + else if (as<IRDifferentialPtrPairType>(pairType)) + { + // Quick optimization: + // If primalVal and diffVal are extracted from the same pointer-pair, + // we can just use the pointer-pair directly. + // + if (auto primalPtrVal = as<IRDifferentialPtrPairGetPrimal>(primalVal)) + { + if (auto diffPtrVal = as<IRDifferentialPtrPairGetDifferential>(diffVal)) + { + if (primalPtrVal->getBase() == diffPtrVal->getBase()) + return primalPtrVal->getBase(); + } + } + return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal) + { + if (as<IRDifferentialPairType>(pairVal->getDataType())) + { + return emitDifferentialValuePairGetDifferential(diffType, pairVal); + } + else if (as<IRDifferentialPtrPairType>(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetDifferential(diffType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal) + { + if (as<IRDifferentialPairType>(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(pairVal); + } + else if (as<IRDifferentialPtrPairType>(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal) + { + if (as<IRDifferentialPairType>(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(primalType, pairVal); + } + else if (as<IRDifferentialPtrPairType>(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(primalType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type)); @@ -4222,7 +4325,7 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args); } - IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); return emitIntrinsicInst( @@ -4232,7 +4335,18 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + + IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as<IRDifferentialPtrPairType>(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPtrPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair) { auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( @@ -4242,7 +4356,7 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair) { return emitIntrinsicInst( primalType, @@ -4251,6 +4365,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair) + { + auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); |
