summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-13 10:57:28 -0700
committerGitHub <noreply@github.com>2023-03-13 10:57:28 -0700
commita911ca6e06ce41e403b80fe6054162393491c8ac (patch)
tree6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-ir.cpp
parent3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff)
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp62
1 files changed, 62 insertions, 0 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2819a6d83..08c066f5d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2459,6 +2459,26 @@ namespace Slang
if (found)
{
memoryArena.rewindToCursor(cursor);
+
+ // If the found inst is defined in the same parent as current insert location but
+ // is located after the insert location, we need to move it to the insert location.
+ auto foundInst = *found;
+ if (foundInst->getParent() && foundInst->getParent() == getInsertLoc().getParent() &&
+ getInsertLoc().getMode() == IRInsertLoc::Mode::Before)
+ {
+ auto insertLoc = getInsertLoc().getInst();
+ bool isAfter = false;
+ for (auto cur = insertLoc->next; cur; cur = cur->next)
+ {
+ if (cur == foundInst)
+ {
+ isAfter = true;
+ break;
+ }
+ }
+ if (isAfter)
+ foundInst->insertBefore(insertLoc);
+ }
return *found;
}
}
@@ -2779,6 +2799,17 @@ namespace Slang
operands);
}
+ IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType(
+ IRType* valueType,
+ IRInst* witnessTable)
+ {
+ IRInst* operands[] = { valueType, witnessTable };
+ return (IRDifferentialPairUserCodeType*)getType(
+ kIROp_DifferentialPairUserCodeType,
+ sizeof(operands) / sizeof(operands[0]),
+ operands);
+ }
+
IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType(
IRInst* func)
{
@@ -3162,6 +3193,18 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential)
+ {
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type));
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type)->getValueType() != nullptr);
+
+ IRInst* args[] = { primal, differential };
+ auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
+ this, kIROp_MakeDifferentialPairUserCode, type, 2, args);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitSpecializeInst(
IRType* type,
IRInst* genericVal,
@@ -3751,6 +3794,25 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
+ {
+ SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPairGetDifferentialUserCode,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimalUserCode(IRInst* diffPair)
+ {
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPairGetPrimalUserCode,
+ 1,
+ &diffPair);
+ }
IRInst* IRBuilder::emitMakeMatrix(
IRType* type,