diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 6f97ce076..14178a86c 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -194,7 +194,7 @@ public: return false; } - bool instHasNonTrivialDerivative(IRInst* inst) + bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst) { switch (inst->getOp()) { @@ -206,7 +206,7 @@ public: return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); } default: - return true; + return isDifferentiableType(diffTypeContext, inst->getDataType()); } } @@ -468,7 +468,7 @@ public: if (auto storeInst = as<IRStore>(inst)) { if (produceDiffSet.Contains(storeInst->getVal()) && - instHasNonTrivialDerivative(storeInst->getVal()) && + instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); |
