diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-04 09:36:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-04 09:36:23 -0700 |
| commit | c6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch) | |
| tree | 6db694b5b4bf94ce48678c73921676f9d305614d /source/slang/slang-ir.cpp | |
| parent | 015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff) | |
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
| -rw-r--r-- | source/slang/slang-ir.cpp | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 083ef98c5..f9686ac5b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1967,6 +1967,7 @@ namespace Slang return getStringSlice() == rhs->getStringSlice(); } case kIROp_VoidLit: + case kIROp_DifferentialBottomValue: { return true; } @@ -2009,6 +2010,7 @@ namespace Slang return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength())); } case kIROp_VoidLit: + case kIROp_DifferentialBottomValue: { return code; } @@ -2074,12 +2076,20 @@ namespace Slang } case kIROp_VoidLit: { - const size_t instSize = prefixSize; + const size_t instSize = prefixSize + sizeof(void*); irValue = static_cast<IRConstant*>( _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); irValue->value.ptrVal = keyInst.value.ptrVal; break; } + case kIROp_DifferentialBottomValue: + { + const size_t instSize = prefixSize + sizeof(void*); + irValue = static_cast<IRConstant*>( + _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); + irValue->value.ptrVal = nullptr; + break; + } case kIROp_StringLit: { const UnownedStringSlice slice = keyInst.getStringSlice(); @@ -2182,6 +2192,17 @@ namespace Slang return _findOrEmitConstant(keyInst); } + IRInst* IRBuilder::getDifferentialBottom() + { + IRType* type = getDifferentialBottomType(); + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.m_op = kIROp_DifferentialBottomValue; + keyInst.typeUse.usedValue = type; + keyInst.value.intVal = 0; + return (IRInst*)_findOrEmitConstant(keyInst); + } + IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice) { IRConstant keyInst; @@ -2564,6 +2585,12 @@ namespace Slang IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); } + IRDifferentialBottomType* IRBuilder::getDifferentialBottomType() + { + return (IRDifferentialBottomType*)getType(kIROp_DifferentialBottomType); + } + + IRAssociatedType* IRBuilder::getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes) { return (IRAssociatedType*)getType(kIROp_AssociatedType, @@ -2760,7 +2787,7 @@ namespace Slang IRDifferentialPairType* IRBuilder::getDifferentialPairType( IRType* valueType, - IRWitnessTable* witnessTable) + IRInst* witnessTable) { IRInst* operands[] = { valueType, witnessTable }; return (IRDifferentialPairType*)getType( @@ -3389,6 +3416,25 @@ namespace Slang return emitIntrinsicInst(type, kIROp_makeVector, argCount, args); } + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + { + auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitMakeMatrix( IRType* type, UInt argCount, |
