From c6e6b7a9177bf4f7fc2f05da36c5952979006d78 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 4 Nov 2022 09:36:23 -0700 Subject: Higher order differentiation. (#2487) Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 50 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 083ef98c5..f9686ac5b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1967,6 +1967,7 @@ namespace Slang return getStringSlice() == rhs->getStringSlice(); } case kIROp_VoidLit: + case kIROp_DifferentialBottomValue: { return true; } @@ -2009,6 +2010,7 @@ namespace Slang return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength())); } case kIROp_VoidLit: + case kIROp_DifferentialBottomValue: { return code; } @@ -2074,12 +2076,20 @@ namespace Slang } case kIROp_VoidLit: { - const size_t instSize = prefixSize; + const size_t instSize = prefixSize + sizeof(void*); irValue = static_cast( _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); irValue->value.ptrVal = keyInst.value.ptrVal; break; } + case kIROp_DifferentialBottomValue: + { + const size_t instSize = prefixSize + sizeof(void*); + irValue = static_cast( + _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); + irValue->value.ptrVal = nullptr; + break; + } case kIROp_StringLit: { const UnownedStringSlice slice = keyInst.getStringSlice(); @@ -2182,6 +2192,17 @@ namespace Slang return _findOrEmitConstant(keyInst); } + IRInst* IRBuilder::getDifferentialBottom() + { + IRType* type = getDifferentialBottomType(); + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.m_op = kIROp_DifferentialBottomValue; + keyInst.typeUse.usedValue = type; + keyInst.value.intVal = 0; + return (IRInst*)_findOrEmitConstant(keyInst); + } + IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice) { IRConstant keyInst; @@ -2564,6 +2585,12 @@ namespace Slang IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); } + IRDifferentialBottomType* IRBuilder::getDifferentialBottomType() + { + return (IRDifferentialBottomType*)getType(kIROp_DifferentialBottomType); + } + + IRAssociatedType* IRBuilder::getAssociatedType(ArrayView constraintTypes) { return (IRAssociatedType*)getType(kIROp_AssociatedType, @@ -2760,7 +2787,7 @@ namespace Slang IRDifferentialPairType* IRBuilder::getDifferentialPairType( IRType* valueType, - IRWitnessTable* witnessTable) + IRInst* witnessTable) { IRInst* operands[] = { valueType, witnessTable }; return (IRDifferentialPairType*)getType( @@ -3389,6 +3416,25 @@ namespace Slang return emitIntrinsicInst(type, kIROp_makeVector, argCount, args); } + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + { + auto valueType = as(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitMakeMatrix( IRType* type, UInt argCount, -- cgit v1.2.3