summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-ir-check-differentiability.cpp
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp15
1 files changed, 13 insertions, 2 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 186b0cc03..14f6394e2 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -39,12 +39,17 @@ public:
return false;
}
-
bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
func = getResolvedInstForDecorations(func);
if (!func)
return false;
+ if (auto substDecor = func->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc());
+ if (!func)
+ return false;
+ }
for (auto decorations : func->getDecorations())
{
@@ -84,7 +89,13 @@ public:
if (!func)
return false;
-
+ if (auto substDecor = func->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc());
+ if (!func)
+ return false;
+ }
+
if (auto existingLevel = differentiableFunctions.TryGetValue(func))
return *existingLevel >= level;