diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-06 13:39:06 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-06 13:39:06 -0800 |
| commit | 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch) | |
| tree | 318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-lower-to-ir.cpp | |
| parent | e70cbe76ce74769069b7384f5f05c62da1ca45ed (diff) | |
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func.
* Fix.
* Download swiftshader with github actions instead of curl on linux.
* Fix github action.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 53 |
1 files changed, 50 insertions, 3 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a84cf9b8d..6803e1cb4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1407,6 +1407,33 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(diff); } + LoweredValInfo visitBackwardDifferentiatePropagateVal(BackwardDifferentiatePropagateVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->emitBackwardDifferentiatePropagateInst(getBuilder()->getTypeKind(), funcVal.val); + return LoweredValInfo::simple(diff); + } + + LoweredValInfo visitBackwardDifferentiatePrimalVal(BackwardDifferentiatePrimalVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->emitBackwardDifferentiatePrimalInst(getBuilder()->getTypeKind(), funcVal.val); + return LoweredValInfo::simple(diff); + } + + LoweredValInfo visitBackwardDifferentiateIntermediateTypeVal(BackwardDifferentiateIntermediateTypeVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->getBackwardDiffIntermediateContextType(funcVal.val); + return LoweredValInfo::simple(diff); + } + LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*) { return LoweredValInfo(); @@ -6816,9 +6843,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> context->irBuilder->addDecoration( interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration); } - auto op = as<ForwardDerivativeRequirementDecl>(requirementDecl) - ? kIROp_ForwardDifferentiableMethodRequirementDictionaryItem - : kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + IROp op = kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + if (as<BackwardDerivativeRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + } + else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiablePropagateMethodRequirementDictionaryItem; + } + else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiablePrimalMethodRequirementDictionaryItem; + } + else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiableIntermediateTypeRequirementDictionaryItem; + } IRInst* args[] = {originalKey, associatedKey}; auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args); assoc->insertAtEnd(decor); @@ -8405,6 +8446,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl) + { + SLANG_UNUSED(decl); + return LoweredValInfo(getBuilder()->getTypeKind()); + } + LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) { // A function declaration may have multiple, target-specific |
