summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 6f97ce076..14178a86c 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -194,7 +194,7 @@ public:
return false;
}
- bool instHasNonTrivialDerivative(IRInst* inst)
+ bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst)
{
switch (inst->getOp())
{
@@ -206,7 +206,7 @@ public:
return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
}
default:
- return true;
+ return isDifferentiableType(diffTypeContext, inst->getDataType());
}
}
@@ -468,7 +468,7 @@ public:
if (auto storeInst = as<IRStore>(inst))
{
if (produceDiffSet.Contains(storeInst->getVal()) &&
- instHasNonTrivialDerivative(storeInst->getVal()) &&
+ instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);