summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-13 10:57:28 -0700
committerGitHub <noreply@github.com>2023-03-13 10:57:28 -0700
commita911ca6e06ce41e403b80fe6054162393491c8ac (patch)
tree6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-ir-check-differentiability.cpp
parent3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff)
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp15
1 files changed, 6 insertions, 9 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 14f6394e2..6f97ce076 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -73,16 +73,13 @@ public:
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
- if (level == DifferentiableLevel::Forward)
+ switch (func->getOp())
{
- switch (func->getOp())
- {
- case kIROp_ForwardDifferentiate:
- case kIROp_BackwardDifferentiate:
- return true;
- default:
- break;
- }
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return isDifferentiableFunc(func->getOperand(0), level);
+ default:
+ break;
}
func = getResolvedInstForDecorations(func);