summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-check-differentiability.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp221
1 files changed, 122 insertions, 99 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index cae47fffd..2a1194ebe 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -14,13 +14,15 @@ public:
enum DifferentiableLevel
{
- Forward, Backward
+ Forward,
+ Backward
};
Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
: InstPassBase(inModule), sink(inSink), sharedContext(nullptr, inModule->getModuleInst())
- {}
+ {
+ }
bool _isFuncMarkedForAutoDiff(IRInst* func)
{
@@ -32,8 +34,7 @@ public:
switch (decorations->getOp())
{
case kIROp_ForwardDifferentiableDecoration:
- case kIROp_BackwardDifferentiableDecoration:
- return true;
+ case kIROp_BackwardDifferentiableDecoration: return true;
}
}
return false;
@@ -62,10 +63,8 @@ public:
break;
case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_BackwardDerivativeDecoration:
- case kIROp_BackwardDifferentiableDecoration:
- return true;
- default:
- break;
+ case kIROp_BackwardDifferentiableDecoration: return true;
+ default: break;
}
}
return false;
@@ -76,7 +75,7 @@ public:
SLANG_ASSERT(as<IRCall>(callInst));
return (
- callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() ||
+ callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() ||
callInst->findDecoration<IRDifferentiableCallDecoration>());
}
@@ -85,15 +84,17 @@ public:
switch (func->getOp())
{
case kIROp_ForwardDifferentiate:
- if (auto fwdDerivative = func->getOperand(0)->findDecoration<IRForwardDerivativeDecoration>())
+ if (auto fwdDerivative =
+ func->getOperand(0)->findDecoration<IRForwardDerivativeDecoration>())
return isDifferentiableFunc(fwdDerivative->getForwardDerivativeFunc(), level);
return isDifferentiableFunc(func->getOperand(0), level);
case kIROp_BackwardDifferentiate:
- if (auto bwdDerivative = func->getOperand(0)->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ if (auto bwdDerivative =
+ func->getOperand(0)
+ ->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
return isDifferentiableFunc(bwdDerivative->getBackwardDerivativeFunc(), level);
return isDifferentiableFunc(func->getOperand(0), level);
- default:
- break;
+ default: break;
}
func = getResolvedInstForDecorations(func);
@@ -126,11 +127,14 @@ public:
return false;
if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>())
return true;
- if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType)
+ if (sharedContext.differentiableInterfaceType &&
+ interfaceType == sharedContext.differentiableInterfaceType)
return true;
- if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>())
+ if (lookupInterfaceMethod->getRequirementKey()
+ ->findDecoration<IRBackwardDerivativeDecoration>())
return true;
- if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRForwardDerivativeDecoration>())
+ if (lookupInterfaceMethod->getRequirementKey()
+ ->findDecoration<IRForwardDerivativeDecoration>())
return level == DifferentiableLevel::Forward;
}
@@ -159,7 +163,9 @@ public:
return false;
}
- bool canAddressHoldDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* addr)
+ bool canAddressHoldDerivative(
+ DifferentiableTypeConformanceContext& diffTypeContext,
+ IRInst* addr)
{
if (!addr)
return false;
@@ -169,41 +175,43 @@ public:
switch (addr->getOp())
{
case kIROp_Var:
- case kIROp_Param:
- return isDifferentiableType(diffTypeContext, addr->getDataType());
+ case kIROp_Param: return isDifferentiableType(diffTypeContext, addr->getDataType());
case kIROp_FieldAddress:
if (!as<IRFieldAddress>(addr)->getField() ||
as<IRFieldAddress>(addr)
- ->getField()
- ->findDecoration<IRDerivativeMemberDecoration>() == nullptr)
+ ->getField()
+ ->findDecoration<IRDerivativeMemberDecoration>() == nullptr)
return false;
addr = as<IRFieldAddress>(addr)->getBase();
break;
case kIROp_GetElementPtr:
- if (!isDifferentiableType(diffTypeContext, as<IRGetElementPtr>(addr)->getBase()->getDataType()))
+ if (!isDifferentiableType(
+ diffTypeContext,
+ as<IRGetElementPtr>(addr)->getBase()->getDataType()))
return false;
addr = as<IRGetElementPtr>(addr)->getBase();
break;
- default:
- return false;
+ default: return false;
}
}
return false;
}
- bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst)
+ bool instHasNonTrivialDerivative(
+ DifferentiableTypeConformanceContext& diffTypeContext,
+ IRInst* inst)
{
switch (inst->getOp())
{
- case kIROp_DetachDerivative:
- return false;
+ case kIROp_DetachDerivative: return false;
case kIROp_Call:
- {
- auto call = as<IRCall>(inst);
- return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
- }
- default:
- return isDifferentiableType(diffTypeContext, inst->getDataType());
+ {
+ auto call = as<IRCall>(inst);
+ return isDifferentiableFunc(
+ call->getCallee(),
+ CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
+ }
+ default: return isDifferentiableType(diffTypeContext, inst->getDataType());
}
}
@@ -246,7 +254,6 @@ public:
sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc);
return;
}
-
}
}
}
@@ -267,12 +274,11 @@ public:
// data flow analysis.
// `produceDiffSet` represents a set of insts that can provide a diff. This is conservative
// on the positive side: a float literal is considered to be able to provide a diff.
- // `carryNonTrivialDiffSet` represents a set of insts that may carry a non-zero diff. This is
- // conservative on the negative side: if the inst does not provide a diff, or if we can prove the diff
- // is zero, we exclude the inst from the set. This makes `carryNonTrivialDiffSet` a strict subset of
- // `produceDiffSet`.
- // `expectDiffSet` is a set of insts that expects their operands to produce a diff. It is an error
- // if they don't.
+ // `carryNonTrivialDiffSet` represents a set of insts that may carry a non-zero diff. This
+ // is conservative on the negative side: if the inst does not provide a diff, or if we can
+ // prove the diff is zero, we exclude the inst from the set. This makes
+ // `carryNonTrivialDiffSet` a strict subset of `produceDiffSet`. `expectDiffSet` is a set of
+ // insts that expects their operands to produce a diff. It is an error if they don't.
InstHashSet produceDiffSet(funcInst->getModule());
InstHashSet expectDiffSet(funcInst->getModule());
InstHashSet carryNonTrivialDiffSet(funcInst->getModule());
@@ -302,15 +308,17 @@ public:
{
switch (inst->getOp())
{
- case kIROp_FloatLit:
- return true;
+ case kIROp_FloatLit: return true;
case kIROp_Call:
return shouldTreatCallAsDifferentiable(inst) ||
- isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType());
+ isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) &&
+ isDifferentiableType(diffTypeContext, inst->getFullType());
case kIROp_Load:
- // We don't have more knowledge on whether diff is available at the destination address.
- // Just assume it is producing diff if the dest address can hold a derivative.
- //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter.
+ // We don't have more knowledge on whether diff is available at the destination
+ // address. Just assume it is producing diff if the dest address can hold a
+ // derivative.
+ // TODO: propagate the info if this is a load of a temporary variable intended
+ // to receive result from an `out` parameter.
return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr());
default:
// default case is to assume the inst produces a diff value if any
@@ -332,17 +340,18 @@ public:
{
switch (inst->getOp())
{
- case kIROp_DetachDerivative:
- return false;
+ case kIROp_DetachDerivative: return false;
case kIROp_Call:
if (shouldTreatCallAsDifferentiable(inst))
return false;
return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) &&
isDifferentiableType(diffTypeContext, inst->getFullType());
case kIROp_Load:
- // We don't have more knowledge on whether diff is available at the destination address.
- // Just assume it is producing diff if the dest address can hold a derivative.
- //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter.
+ // We don't have more knowledge on whether diff is available at the destination
+ // address. Just assume it is producing diff if the dest address can hold a
+ // derivative.
+ // TODO: propagate the info if this is a load of a temporary variable intended
+ // to receive result from an `out` parameter.
return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr());
default:
// default case is to assume the inst produces a diff value if any
@@ -387,7 +396,8 @@ public:
{
for (auto p : block->getPredecessors())
{
- // A Phi Node is producing diff if any of its candidate values are producing diff.
+ // A Phi Node is producing diff if any of its candidate values are
+ // producing diff.
if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
{
if (branch->getArgCount() > paramIndex)
@@ -418,15 +428,17 @@ public:
}
break;
case kIROp_Store:
- {
- auto storeInst = as<IRStore>(inst);
- if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) &&
- isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
{
- addToExpectDiffWorkList(storeInst->getVal());
+ auto storeInst = as<IRStore>(inst);
+ if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) &&
+ isDifferentiableType(
+ diffTypeContext,
+ as<IRStore>(inst)->getPtr()->getDataType()))
+ {
+ addToExpectDiffWorkList(storeInst->getVal());
+ }
}
- }
- break;
+ break;
case kIROp_Return:
if (auto returnVal = as<IRReturn>(inst)->getVal())
{
@@ -437,8 +449,7 @@ public:
}
}
break;
- default:
- break;
+ default: break;
}
}
}
@@ -456,64 +467,69 @@ public:
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
- if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
+ if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
!isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
{
sink->diagnose(
inst,
Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
getResolvedInstForDecorations(call->getCallee()),
- requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
+ : "backward");
}
}
}
switch (inst->getOp())
{
case kIROp_Param:
- {
- auto block = as<IRBlock>(inst->getParent());
- if (block != funcInst->getFirstBlock())
{
- auto paramIndex = getParamIndexInBlock(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst));
- if (paramIndex != -1)
+ auto block = as<IRBlock>(inst->getParent());
+ if (block != funcInst->getFirstBlock())
{
- for (auto p : block->getPredecessors())
+ auto paramIndex = getParamIndexInBlock(
+ as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst));
+ if (paramIndex != -1)
{
- // A Phi Node is producing diff if any of its candidate values are producing diff.
- if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
+ for (auto p : block->getPredecessors())
{
- if (branch->getArgCount() > (UInt)paramIndex)
+ // A Phi Node is producing diff if any of its candidate values are
+ // producing diff.
+ if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
{
- auto arg = branch->getArg(paramIndex);
- addToExpectDiffWorkList(arg);
+ if (branch->getArgCount() > (UInt)paramIndex)
+ {
+ auto arg = branch->getArg(paramIndex);
+ addToExpectDiffWorkList(arg);
+ }
}
}
}
}
+ break;
}
- break;
- }
case kIROp_Call:
- {
- auto callInst = as<IRCall>(inst);
- if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>())
- continue;
- auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
- if (!calleeFuncType) continue;
- if (calleeFuncType->getParamCount() != callInst->getArgCount())
- continue;
- for (UInt a = 0; a < callInst->getArgCount(); a++)
{
- auto arg = callInst->getArg(a);
- auto paramType = calleeFuncType->getParamType(a);
- if (!isDifferentiableType(diffTypeContext, paramType))
+ auto callInst = as<IRCall>(inst);
+ if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>())
+ continue;
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType)
+ continue;
+ if (calleeFuncType->getParamCount() != callInst->getArgCount())
continue;
- addToExpectDiffWorkList(arg);
+ for (UInt a = 0; a < callInst->getArgCount(); a++)
+ {
+ auto arg = callInst->getArg(a);
+ auto paramType = calleeFuncType->getParamType(a);
+ if (!isDifferentiableType(diffTypeContext, paramType))
+ continue;
+ addToExpectDiffWorkList(arg);
+ }
+ break;
}
- break;
- }
default:
- // Default behavior is to request all differentiable operands to provide differential.
+ // Default behavior is to request all differentiable operands to provide
+ // differential.
for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++)
{
auto operand = inst->getOperand(opIndex);
@@ -542,7 +558,8 @@ public:
}
if (!hasBackEdge)
continue;
- if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>())
+ if (loop->findDecoration<IRLoopMaxItersDecoration>() ||
+ loop->findDecoration<IRForceUnrollDecoration>())
{
// We are good.
}
@@ -552,9 +569,9 @@ public:
}
}
- // Make sure all stores of differentiable values are into addresses that can hold derivatives.
- // If we are assigning a value to a non-differentiable location, we need to make sure
- // that value doesn't carray a non-zero diff.
+ // Make sure all stores of differentiable values are into addresses that can hold
+ // derivatives. If we are assigning a value to a non-differentiable location, we need to
+ // make sure that value doesn't carray a non-zero diff.
for (auto block : funcInst->getBlocks())
{
for (auto inst : block->getChildren())
@@ -564,7 +581,9 @@ public:
if (carryNonTrivialDiffSet.contains(storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
- sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
+ sink->diagnose(
+ storeInst->sourceLoc,
+ Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
}
}
else if (auto callInst = as<IRCall>(inst))
@@ -586,7 +605,10 @@ public:
{
if (!canAddressHoldDerivative(diffTypeContext, arg))
{
- sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg);
+ sink->diagnose(
+ arg->sourceLoc,
+ Diagnostics::
+ lossOfDerivativeUsingNonDifferentiableLocationAsOutArg);
}
}
}
@@ -632,7 +654,8 @@ public:
{
if (auto genericInst = as<IRGeneric>(inst))
{
- if (auto innerFunc = as<IRGlobalValueWithCode>(findInnerMostGenericReturnVal(genericInst)))
+ if (auto innerFunc =
+ as<IRGlobalValueWithCode>(findInnerMostGenericReturnVal(genericInst)))
processFunc(innerFunc);
}
else if (auto funcInst = as<IRGlobalValueWithCode>(inst))