summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-04 09:36:23 -0700
committerGitHub <noreply@github.com>2022-11-04 09:36:23 -0700
commitc6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch)
tree6db694b5b4bf94ce48678c73921676f9d305614d /source/slang/slang-ir.cpp
parent015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff)
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp50
1 files changed, 48 insertions, 2 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 083ef98c5..f9686ac5b 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1967,6 +1967,7 @@ namespace Slang
return getStringSlice() == rhs->getStringSlice();
}
case kIROp_VoidLit:
+ case kIROp_DifferentialBottomValue:
{
return true;
}
@@ -2009,6 +2010,7 @@ namespace Slang
return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength()));
}
case kIROp_VoidLit:
+ case kIROp_DifferentialBottomValue:
{
return code;
}
@@ -2074,12 +2076,20 @@ namespace Slang
}
case kIROp_VoidLit:
{
- const size_t instSize = prefixSize;
+ const size_t instSize = prefixSize + sizeof(void*);
irValue = static_cast<IRConstant*>(
_createInst(instSize, keyInst.getFullType(), keyInst.getOp()));
irValue->value.ptrVal = keyInst.value.ptrVal;
break;
}
+ case kIROp_DifferentialBottomValue:
+ {
+ const size_t instSize = prefixSize + sizeof(void*);
+ irValue = static_cast<IRConstant*>(
+ _createInst(instSize, keyInst.getFullType(), keyInst.getOp()));
+ irValue->value.ptrVal = nullptr;
+ break;
+ }
case kIROp_StringLit:
{
const UnownedStringSlice slice = keyInst.getStringSlice();
@@ -2182,6 +2192,17 @@ namespace Slang
return _findOrEmitConstant(keyInst);
}
+ IRInst* IRBuilder::getDifferentialBottom()
+ {
+ IRType* type = getDifferentialBottomType();
+ IRConstant keyInst;
+ memset(&keyInst, 0, sizeof(keyInst));
+ keyInst.m_op = kIROp_DifferentialBottomValue;
+ keyInst.typeUse.usedValue = type;
+ keyInst.value.intVal = 0;
+ return (IRInst*)_findOrEmitConstant(keyInst);
+ }
+
IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice)
{
IRConstant keyInst;
@@ -2564,6 +2585,12 @@ namespace Slang
IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); }
+ IRDifferentialBottomType* IRBuilder::getDifferentialBottomType()
+ {
+ return (IRDifferentialBottomType*)getType(kIROp_DifferentialBottomType);
+ }
+
+
IRAssociatedType* IRBuilder::getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes)
{
return (IRAssociatedType*)getType(kIROp_AssociatedType,
@@ -2760,7 +2787,7 @@ namespace Slang
IRDifferentialPairType* IRBuilder::getDifferentialPairType(
IRType* valueType,
- IRWitnessTable* witnessTable)
+ IRInst* witnessTable)
{
IRInst* operands[] = { valueType, witnessTable };
return (IRDifferentialPairType*)getType(
@@ -3389,6 +3416,25 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_makeVector, argCount, args);
}
+ IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPairGetDifferential,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
+ {
+ auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitMakeMatrix(
IRType* type,
UInt argCount,