summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-06 13:39:06 -0800
committerGitHub <noreply@github.com>2023-01-06 13:39:06 -0800
commit33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch)
tree318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-lower-to-ir.cpp
parente70cbe76ce74769069b7384f5f05c62da1ca45ed (diff)
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp53
1 files changed, 50 insertions, 3 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a84cf9b8d..6803e1cb4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1407,6 +1407,33 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(diff);
}
+ LoweredValInfo visitBackwardDifferentiatePropagateVal(BackwardDifferentiatePropagateVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->emitBackwardDifferentiatePropagateInst(getBuilder()->getTypeKind(), funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
+ LoweredValInfo visitBackwardDifferentiatePrimalVal(BackwardDifferentiatePrimalVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->emitBackwardDifferentiatePrimalInst(getBuilder()->getTypeKind(), funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
+ LoweredValInfo visitBackwardDifferentiateIntermediateTypeVal(BackwardDifferentiateIntermediateTypeVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->getBackwardDiffIntermediateContextType(funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*)
{
return LoweredValInfo();
@@ -6816,9 +6843,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
context->irBuilder->addDecoration(
interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration);
}
- auto op = as<ForwardDerivativeRequirementDecl>(requirementDecl)
- ? kIROp_ForwardDifferentiableMethodRequirementDictionaryItem
- : kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ IROp op = kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ if (as<BackwardDerivativeRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ }
+ else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiablePropagateMethodRequirementDictionaryItem;
+ }
+ else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiablePrimalMethodRequirementDictionaryItem;
+ }
+ else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiableIntermediateTypeRequirementDictionaryItem;
+ }
IRInst* args[] = {originalKey, associatedKey};
auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args);
assoc->insertAtEnd(decor);
@@ -8405,6 +8446,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl)
+ {
+ SLANG_UNUSED(decl);
+ return LoweredValInfo(getBuilder()->getTypeKind());
+ }
+
LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl)
{
// A function declaration may have multiple, target-specific