summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 03:10:28 -0400
committerGitHub <noreply@github.com>2024-09-19 00:10:28 -0700
commitccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch)
tree435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-ir.cpp
parent1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff)
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp141
1 files changed, 137 insertions, 4 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6c7691d13..b89929f55 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3022,6 +3022,17 @@ namespace Slang
operands);
}
+ IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType(
+ IRType* valueType,
+ IRInst* witnessTable)
+ {
+ IRInst* operands[] = { valueType, witnessTable };
+ return (IRDifferentialPtrPairType*)getType(
+ kIROp_DifferentialPtrPairType,
+ sizeof(operands) / sizeof(operands[0]),
+ operands);
+ }
+
IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType(
IRType* valueType,
IRInst* witnessTable)
@@ -3503,7 +3514,7 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
+ IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr);
@@ -3516,6 +3527,98 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential)
+ {
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type));
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)->getValueType() != nullptr);
+
+ IRInst* args[] = {primal, differential};
+ auto inst = createInstWithTrailingArgs<IRMakeDifferentialPtrPair>(
+ this, kIROp_MakeDifferentialPtrPair, type, 2, args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal)
+ {
+ if (as<IRDifferentialPairType>(pairType))
+ {
+ return emitMakeDifferentialValuePair(pairType, primalVal, diffVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairType))
+ {
+ // Quick optimization:
+ // If primalVal and diffVal are extracted from the same pointer-pair,
+ // we can just use the pointer-pair directly.
+ //
+ if (auto primalPtrVal = as<IRDifferentialPtrPairGetPrimal>(primalVal))
+ {
+ if (auto diffPtrVal = as<IRDifferentialPtrPairGetDifferential>(diffVal))
+ {
+ if (primalPtrVal->getBase() == diffPtrVal->getBase())
+ return primalPtrVal->getBase();
+ }
+ }
+ return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetDifferential(diffType, pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetDifferential(diffType, pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetPrimal(pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetPrimal(pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetPrimal(primalType, pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetPrimal(primalType, pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type));
@@ -4222,7 +4325,7 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args);
}
- IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
return emitIntrinsicInst(
@@ -4232,7 +4335,18 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
+
+ IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ {
+ SLANG_ASSERT(as<IRDifferentialPtrPairType>(diffPair->getDataType()));
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPtrPairGetDifferential,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair)
{
auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
return emitIntrinsicInst(
@@ -4242,7 +4356,7 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair)
{
return emitIntrinsicInst(
primalType,
@@ -4251,6 +4365,25 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair)
+ {
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPtrPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ primalType,
+ kIROp_DifferentialPtrPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));