diff options
| author | Yong He <yonghe@outlook.com> | 2023-10-06 14:03:18 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-06 14:03:18 -0700 |
| commit | 17c7163c2ae8fc290e70b43d8700b68ef18b1ee1 (patch) | |
| tree | 09df040039fb1221810f956bb83871430cbac47f /source/slang/slang-ir-check-differentiability.cpp | |
| parent | 4547125ce945140dc10542e9606b225dd06159b8 (diff) | |
Small type system fixes. (#3265)
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index b937fe052..ddb70d779 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -85,7 +85,12 @@ public: switch (func->getOp()) { case kIROp_ForwardDifferentiate: + if (auto fwdDerivative = func->getOperand(0)->findDecoration<IRForwardDerivativeDecoration>()) + return isDifferentiableFunc(fwdDerivative->getForwardDerivativeFunc(), level); + return isDifferentiableFunc(func->getOperand(0), level); case kIROp_BackwardDifferentiate: + if (auto bwdDerivative = func->getOperand(0)->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + return isDifferentiableFunc(bwdDerivative->getBackwardDerivativeFunc(), level); return isDifferentiableFunc(func->getOperand(0), level); default: break; |
