From 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 6 Jan 2023 13:39:06 -0800 Subject: Split bwd_diff op into separate ops for primal and propagate func. (#2582) * Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He --- source/slang/slang-ir.cpp | 48 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index d8a8fb7c4..9e0e328bd 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -300,6 +300,11 @@ namespace Slang return as(getNextInst()); } + IRParam* IRParam::getPrevParam() + { + return as(getPrevInst()); + } + // IRArrayTypeBase IRInst* IRArrayTypeBase::getElementCount() @@ -2802,6 +2807,15 @@ namespace Slang operands); } + IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType( + IRInst* func) + { + return (IRBackwardDiffIntermediateContextType*)getType( + kIROp_BackwardDiffIntermediateContextType, + 1, + &func); + } + IRFuncType* IRBuilder::getFuncType( UInt paramCount, IRType* const* paramTypes, @@ -3129,6 +3143,28 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst( + this, + kIROp_BackwardDifferentiatePrimal, + type, + baseFn); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst( + this, + kIROp_BackwardDifferentiatePropagate, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as(type)); @@ -6622,6 +6658,7 @@ namespace Slang case kIROp_UnpackAnyValue: case kIROp_Reinterpret: case kIROp_GetNativePtr: + case kIROp_BackwardDiffIntermediateContextType: return false; case kIROp_ForwardDifferentiate: @@ -6904,6 +6941,16 @@ namespace Slang } return nullptr; } + + IRInst* getGenericReturnVal(IRInst* inst) + { + if (auto gen = as(inst)) + { + return findGenericReturnVal(gen); + } + return inst; + } + } // namespace Slang #if SLANG_VC @@ -6917,4 +6964,3 @@ SLANG_API const int SlangDebug__IROpStringLit = Slang::kIROp_StringLit; SLANG_API const int SlangDebug__IROpIntLit = Slang::kIROp_IntLit; #endif #endif - -- cgit v1.2.3