summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-16 23:46:14 -0700
committerGitHub <noreply@github.com>2023-03-16 23:46:14 -0700
commit9476d4543f4336a66308e55f722b0b0b2bd69dd2 (patch)
treeff3a0514249f5c3975177bf053c5cb038e37acc8 /source/slang/slang-ir-check-differentiability.cpp
parent77d3630eef4ea1c4b0424a46526a6be476a89230 (diff)
Fix Phi simplification bug. (#2710)
* Fix Phi simplification bug. * Fix up. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix test. * Fix test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 6f97ce076..14178a86c 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -194,7 +194,7 @@ public:
return false;
}
- bool instHasNonTrivialDerivative(IRInst* inst)
+ bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst)
{
switch (inst->getOp())
{
@@ -206,7 +206,7 @@ public:
return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
}
default:
- return true;
+ return isDifferentiableType(diffTypeContext, inst->getDataType());
}
}
@@ -468,7 +468,7 @@ public:
if (auto storeInst = as<IRStore>(inst))
{
if (produceDiffSet.Contains(storeInst->getVal()) &&
- instHasNonTrivialDerivative(storeInst->getVal()) &&
+ instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);