summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-06 13:39:06 -0800
committerGitHub <noreply@github.com>2023-01-06 13:39:06 -0800
commit33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch)
tree318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-ir.cpp
parente70cbe76ce74769069b7384f5f05c62da1ca45ed (diff)
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp48
1 files changed, 47 insertions, 1 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index d8a8fb7c4..9e0e328bd 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -300,6 +300,11 @@ namespace Slang
return as<IRParam>(getNextInst());
}
+ IRParam* IRParam::getPrevParam()
+ {
+ return as<IRParam>(getPrevInst());
+ }
+
// IRArrayTypeBase
IRInst* IRArrayTypeBase::getElementCount()
@@ -2802,6 +2807,15 @@ namespace Slang
operands);
}
+ IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType(
+ IRInst* func)
+ {
+ return (IRBackwardDiffIntermediateContextType*)getType(
+ kIROp_BackwardDiffIntermediateContextType,
+ 1,
+ &func);
+ }
+
IRFuncType* IRBuilder::getFuncType(
UInt paramCount,
IRType* const* paramTypes,
@@ -3129,6 +3143,28 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRBackwardDifferentiatePrimal>(
+ this,
+ kIROp_BackwardDifferentiatePrimal,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRBackwardDifferentiatePropagate>(
+ this,
+ kIROp_BackwardDifferentiatePropagate,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
@@ -6622,6 +6658,7 @@ namespace Slang
case kIROp_UnpackAnyValue:
case kIROp_Reinterpret:
case kIROp_GetNativePtr:
+ case kIROp_BackwardDiffIntermediateContextType:
return false;
case kIROp_ForwardDifferentiate:
@@ -6904,6 +6941,16 @@ namespace Slang
}
return nullptr;
}
+
+ IRInst* getGenericReturnVal(IRInst* inst)
+ {
+ if (auto gen = as<IRGeneric>(inst))
+ {
+ return findGenericReturnVal(gen);
+ }
+ return inst;
+ }
+
} // namespace Slang
#if SLANG_VC
@@ -6917,4 +6964,3 @@ SLANG_API const int SlangDebug__IROpStringLit = Slang::kIROp_StringLit;
SLANG_API const int SlangDebug__IROpIntLit = Slang::kIROp_IntLit;
#endif
#endif
-