summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp141
1 files changed, 137 insertions, 4 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6c7691d13..b89929f55 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3022,6 +3022,17 @@ namespace Slang
operands);
}
+ IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType(
+ IRType* valueType,
+ IRInst* witnessTable)
+ {
+ IRInst* operands[] = { valueType, witnessTable };
+ return (IRDifferentialPtrPairType*)getType(
+ kIROp_DifferentialPtrPairType,
+ sizeof(operands) / sizeof(operands[0]),
+ operands);
+ }
+
IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType(
IRType* valueType,
IRInst* witnessTable)
@@ -3503,7 +3514,7 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
+ IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr);
@@ -3516,6 +3527,98 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential)
+ {
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type));
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)->getValueType() != nullptr);
+
+ IRInst* args[] = {primal, differential};
+ auto inst = createInstWithTrailingArgs<IRMakeDifferentialPtrPair>(
+ this, kIROp_MakeDifferentialPtrPair, type, 2, args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal)
+ {
+ if (as<IRDifferentialPairType>(pairType))
+ {
+ return emitMakeDifferentialValuePair(pairType, primalVal, diffVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairType))
+ {
+ // Quick optimization:
+ // If primalVal and diffVal are extracted from the same pointer-pair,
+ // we can just use the pointer-pair directly.
+ //
+ if (auto primalPtrVal = as<IRDifferentialPtrPairGetPrimal>(primalVal))
+ {
+ if (auto diffPtrVal = as<IRDifferentialPtrPairGetDifferential>(diffVal))
+ {
+ if (primalPtrVal->getBase() == diffPtrVal->getBase())
+ return primalPtrVal->getBase();
+ }
+ }
+ return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetDifferential(diffType, pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetDifferential(diffType, pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetPrimal(pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetPrimal(pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetPrimal(primalType, pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetPrimal(primalType, pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type));
@@ -4222,7 +4325,7 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args);
}
- IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
return emitIntrinsicInst(
@@ -4232,7 +4335,18 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
+
+ IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ {
+ SLANG_ASSERT(as<IRDifferentialPtrPairType>(diffPair->getDataType()));
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPtrPairGetDifferential,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair)
{
auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
return emitIntrinsicInst(
@@ -4242,7 +4356,7 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair)
{
return emitIntrinsicInst(
primalType,
@@ -4251,6 +4365,25 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair)
+ {
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPtrPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ primalType,
+ kIROp_DifferentialPtrPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));