diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-06 13:39:06 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-06 13:39:06 -0800 |
| commit | 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch) | |
| tree | 318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-ir.cpp | |
| parent | e70cbe76ce74769069b7384f5f05c62da1ca45ed (diff) | |
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func.
* Fix.
* Download swiftshader with github actions instead of curl on linux.
* Fix github action.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
| -rw-r--r-- | source/slang/slang-ir.cpp | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index d8a8fb7c4..9e0e328bd 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -300,6 +300,11 @@ namespace Slang return as<IRParam>(getNextInst()); } + IRParam* IRParam::getPrevParam() + { + return as<IRParam>(getPrevInst()); + } + // IRArrayTypeBase IRInst* IRArrayTypeBase::getElementCount() @@ -2802,6 +2807,15 @@ namespace Slang operands); } + IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType( + IRInst* func) + { + return (IRBackwardDiffIntermediateContextType*)getType( + kIROp_BackwardDiffIntermediateContextType, + 1, + &func); + } + IRFuncType* IRBuilder::getFuncType( UInt paramCount, IRType* const* paramTypes, @@ -3129,6 +3143,28 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRBackwardDifferentiatePrimal>( + this, + kIROp_BackwardDifferentiatePrimal, + type, + baseFn); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRBackwardDifferentiatePropagate>( + this, + kIROp_BackwardDifferentiatePropagate, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)); @@ -6622,6 +6658,7 @@ namespace Slang case kIROp_UnpackAnyValue: case kIROp_Reinterpret: case kIROp_GetNativePtr: + case kIROp_BackwardDiffIntermediateContextType: return false; case kIROp_ForwardDifferentiate: @@ -6904,6 +6941,16 @@ namespace Slang } return nullptr; } + + IRInst* getGenericReturnVal(IRInst* inst) + { + if (auto gen = as<IRGeneric>(inst)) + { + return findGenericReturnVal(gen); + } + return inst; + } + } // namespace Slang #if SLANG_VC @@ -6917,4 +6964,3 @@ SLANG_API const int SlangDebug__IROpStringLit = Slang::kIROp_StringLit; SLANG_API const int SlangDebug__IROpIntLit = Slang::kIROp_IntLit; #endif #endif - |
