summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-19 08:58:20 -0800
committerGitHub <noreply@github.com>2023-01-19 08:58:20 -0800
commit6fae15cd1210d8b664243d640e70ca47dccf9752 (patch)
treed3235149f587ed18147f7a0d916932e199dce888 /source
parent0586f3298fa7d554fa2682103eefba88740d6758 (diff)
Add diagnostic for calling non-bwd-diff func from bwd-diff func. (#2602)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp42
-rw-r--r--source/slang/slang-check-expr.cpp7
-rw-r--r--source/slang/slang-check-impl.h8
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp82
5 files changed, 92 insertions, 49 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index a535ba104..7b5f85b60 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6894,38 +6894,34 @@ namespace Slang
bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func)
{
- // A function is differentiable if it is marked as differentiable, or it
- // has an associated derivative function.
- if (func->findModifier<DifferentiableAttribute>())
- return true;
- for (auto assocDecl : getAssociatedDeclsForDecl(func))
- {
- switch (assocDecl.kind)
- {
- case DeclAssociationKind::ForwardDerivativeFunc:
- case DeclAssociationKind::BackwardDerivativeFunc:
- return true;
- default:
- break;
- }
- }
- return false;
+ return getFuncDifferentiableLevel(func) != FunctionDifferentiableLevel::None;
}
bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func)
{
- // A function is differentiable if it is marked as differentiable, or it
- // has an associated derivative function.
+ return getFuncDifferentiableLevel(func) == FunctionDifferentiableLevel::Backward;
+ }
+
+ FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func)
+ {
if (func->findModifier<BackwardDifferentiableAttribute>())
- return true;
+ return FunctionDifferentiableLevel::Backward;
if (func->findModifier<BackwardDerivativeAttribute>())
- return true;
+ return FunctionDifferentiableLevel::Backward;
+
+ FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None;
+ if (func->findModifier<DifferentiableAttribute>())
+ diffLevel = FunctionDifferentiableLevel::Forward;
+
for (auto assocDecl : getAssociatedDeclsForDecl(func))
{
switch (assocDecl.kind)
{
case DeclAssociationKind::BackwardDerivativeFunc:
- return true;
+ return FunctionDifferentiableLevel::Backward;
+ case DeclAssociationKind::ForwardDerivativeFunc:
+ diffLevel = FunctionDifferentiableLevel::Forward;
+ break;
default:
break;
}
@@ -6937,12 +6933,12 @@ namespace Slang
case BuiltinRequirementKind::DAddFunc:
case BuiltinRequirementKind::DMulFunc:
case BuiltinRequirementKind::DZeroFunc:
- return true;
+ return FunctionDifferentiableLevel::Backward;
default:
break;
}
}
- return false;
+ return diffLevel;
}
List<ExtensionDecl*> const& getCandidateExtensions(
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2fc18628e..43124b535 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1967,6 +1967,10 @@ namespace Slang
if (m_parentDifferentiableAttr)
{
+ FunctionDifferentiableLevel callerDiffLevel = FunctionDifferentiableLevel::None;
+ if (m_parentFunc)
+ callerDiffLevel = getShared()->getFuncDifferentiableLevel(m_parentFunc);
+
if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr))
{
// Register types for final resolved invoke arguments again.
@@ -1978,7 +1982,8 @@ namespace Slang
{
if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl()))
{
- if (getShared()->isDifferentiableFunc(calleeDecl))
+ auto calleeDiffLevel = getShared()->getFuncDifferentiableLevel(calleeDecl);
+ if (calleeDiffLevel >= callerDiffLevel)
{
if (!m_treatAsDifferentiableExpr)
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index fb47a38c1..6099febb5 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -11,7 +11,12 @@
namespace Slang
{
-
+ enum class FunctionDifferentiableLevel
+ {
+ None,
+ Forward,
+ Backward
+ };
/// Should the given `decl` be treated as a static rather than instance declaration?
bool isEffectivelyStatic(
Decl* decl);
@@ -292,6 +297,7 @@ namespace Slang
bool isDifferentiableFunc(FunctionDeclBase* func);
bool isBackwardDifferentiableFunc(FunctionDeclBase* func);
+ FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func);
private:
/// Mapping from type declarations to the known extensiosn that apply to them
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 329bf615b..4820c430f 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -576,7 +576,7 @@ DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'vo
DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.")
DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2")
DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.")
-DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-differentiable function `$0`, use 'no_diff' to clarify intention.")
+DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use 'no_diff' to clarify intention.")
DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.")
DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal")
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 8413e7e79..cb7290036 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -12,7 +12,11 @@ public:
DiagnosticSink* sink;
AutoDiffSharedContext sharedContext;
- HashSet<IRInst*> differentiableFunctions;
+ enum DifferentiableLevel
+ {
+ Forward, Backward
+ };
+ Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
: InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
@@ -59,7 +63,7 @@ public:
}
- bool _isDifferentiableFuncImpl(IRInst* func)
+ bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
func = getLeafFunc(func);
if (!func)
@@ -71,32 +75,41 @@ public:
{
case kIROp_ForwardDerivativeDecoration:
case kIROp_ForwardDifferentiableDecoration:
+ if (level == DifferentiableLevel::Forward)
+ return true;
+ break;
case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_BackwardDerivativeDecoration:
case kIROp_BackwardDifferentiableDecoration:
return true;
+ default:
+ break;
}
}
return false;
}
- bool isDifferentiableFunc(IRInst* func)
+ bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
- switch (func->getOp())
+ if (level == DifferentiableLevel::Forward)
{
- case kIROp_ForwardDifferentiate:
- case kIROp_BackwardDifferentiate:
- return true;
- default:
- break;
+ switch (func->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return true;
+ default:
+ break;
+ }
}
- func = getSpecializedVal(func);
+ func = getLeafFunc(func);
if (!func)
return false;
- if (differentiableFunctions.Contains(func))
- return true;
+
+ if (auto existingLevel = differentiableFunctions.TryGetValue(func))
+ return *existingLevel >= level;
if (func->findDecoration<IRTreatAsDifferentiableDecoration>())
return true;
@@ -125,7 +138,10 @@ public:
{
if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
{
- return true;
+ if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward)
+ return true;
+ if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward)
+ return true;
}
}
}
@@ -135,7 +151,11 @@ public:
{
if (as<IRGeneric>(func))
{
- return differentiableFunctions.Contains(func);
+ if (auto existingLevel = differentiableFunctions.TryGetValue(func))
+ {
+ if (*existingLevel >= level)
+ return true;
+ }
}
}
return false;
@@ -235,6 +255,10 @@ public:
if (differentiableInputs == 0)
sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput);
+ DifferentiableLevel requiredDiffLevel = DifferentiableLevel::Forward;
+ if (isBackwardDifferentiableFunc(funcInst))
+ requiredDiffLevel = DifferentiableLevel::Backward;
+
auto isInstProducingDiff = [&](IRInst* inst) -> bool
{
switch (inst->getOp())
@@ -242,7 +266,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
+ return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel);
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
// Just assume it is producing diff.
@@ -310,7 +334,7 @@ public:
switch (inst->getOp())
{
case kIROp_Call:
- if (isDifferentiableFunc(as<IRCall>(inst)->getCallee()))
+ if (isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel))
{
addToExpectDiffWorkList(inst);
}
@@ -349,7 +373,11 @@ public:
{
if (auto call = as<IRCall>(inst))
{
- sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee()));
+ sink->diagnose(
+ inst,
+ Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
+ getLeafFunc(call->getCallee()),
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
}
}
switch (inst->getOp())
@@ -395,22 +423,30 @@ public:
void processModule()
{
// Collect set of differentiable functions.
- HashSet<UnownedStringSlice> differentiableSymbolNames;
+ HashSet<UnownedStringSlice> fwdDifferentiableSymbolNames, bwdDifferentiableSymbolNames;
for (auto inst : module->getGlobalInsts())
{
- if (_isDifferentiableFuncImpl(inst))
+ if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Backward))
+ {
+ if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+ bwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
+ differentiableFunctions.Add(inst, DifferentiableLevel::Backward);
+ }
+ else if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Forward))
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
- differentiableSymbolNames.Add(linkageDecor->getMangledName());
- differentiableFunctions.Add(inst);
+ fwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
+ differentiableFunctions.Add(inst, DifferentiableLevel::Forward);
}
}
for (auto inst : module->getGlobalInsts())
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
{
- if (differentiableSymbolNames.Contains(linkageDecor->getMangledName()))
- differentiableFunctions.Add(inst);
+ if (bwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ differentiableFunctions[inst] = DifferentiableLevel::Backward;
+ else if (fwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ differentiableFunctions.AddIfNotExists(inst, DifferentiableLevel::Forward);
}
}