summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-25 14:48:01 -0800
committerGitHub <noreply@github.com>2023-01-25 14:48:01 -0800
commitaa6814be1f7dea20597ae34d477e79e53d4a543f (patch)
tree15b8ad69e2c4169e12a0ad6e970fe511daa4beb7 /source/slang/slang-ir-check-differentiability.cpp
parentae11538f5d667b11d3b3191a827093f3727eed1b (diff)
Cleanup IR representation of interface member derivative. (#2610)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp20
1 files changed, 4 insertions, 16 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index f8d70c8ed..8cefa6a04 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -131,22 +131,10 @@ public:
return true;
if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType)
return true;
- auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
- if (!dictDecor)
- return false;
- for (auto child : dictDecor->getChildren())
- {
- if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child))
- {
- if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
- {
- if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward)
- return true;
- if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward)
- return true;
- }
- }
- }
+ if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>())
+ return true;
+ if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRForwardDerivativeDecoration>())
+ return level == DifferentiableLevel::Forward;
}
for (; func; func = func->parent)