diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-check-differentiability.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 221 |
1 files changed, 122 insertions, 99 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index cae47fffd..2a1194ebe 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -14,13 +14,15 @@ public: enum DifferentiableLevel { - Forward, Backward + Forward, + Backward }; Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions; CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink) : InstPassBase(inModule), sink(inSink), sharedContext(nullptr, inModule->getModuleInst()) - {} + { + } bool _isFuncMarkedForAutoDiff(IRInst* func) { @@ -32,8 +34,7 @@ public: switch (decorations->getOp()) { case kIROp_ForwardDifferentiableDecoration: - case kIROp_BackwardDifferentiableDecoration: - return true; + case kIROp_BackwardDifferentiableDecoration: return true; } } return false; @@ -62,10 +63,8 @@ public: break; case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_BackwardDerivativeDecoration: - case kIROp_BackwardDifferentiableDecoration: - return true; - default: - break; + case kIROp_BackwardDifferentiableDecoration: return true; + default: break; } } return false; @@ -76,7 +75,7 @@ public: SLANG_ASSERT(as<IRCall>(callInst)); return ( - callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() || + callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() || callInst->findDecoration<IRDifferentiableCallDecoration>()); } @@ -85,15 +84,17 @@ public: switch (func->getOp()) { case kIROp_ForwardDifferentiate: - if (auto fwdDerivative = func->getOperand(0)->findDecoration<IRForwardDerivativeDecoration>()) + if (auto fwdDerivative = + func->getOperand(0)->findDecoration<IRForwardDerivativeDecoration>()) return isDifferentiableFunc(fwdDerivative->getForwardDerivativeFunc(), level); return isDifferentiableFunc(func->getOperand(0), level); case kIROp_BackwardDifferentiate: - if (auto bwdDerivative = func->getOperand(0)->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + if (auto bwdDerivative = + func->getOperand(0) + ->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) return isDifferentiableFunc(bwdDerivative->getBackwardDerivativeFunc(), level); return isDifferentiableFunc(func->getOperand(0), level); - default: - break; + default: break; } func = getResolvedInstForDecorations(func); @@ -126,11 +127,14 @@ public: return false; if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>()) return true; - if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType) + if (sharedContext.differentiableInterfaceType && + interfaceType == sharedContext.differentiableInterfaceType) return true; - if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>()) + if (lookupInterfaceMethod->getRequirementKey() + ->findDecoration<IRBackwardDerivativeDecoration>()) return true; - if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRForwardDerivativeDecoration>()) + if (lookupInterfaceMethod->getRequirementKey() + ->findDecoration<IRForwardDerivativeDecoration>()) return level == DifferentiableLevel::Forward; } @@ -159,7 +163,9 @@ public: return false; } - bool canAddressHoldDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* addr) + bool canAddressHoldDerivative( + DifferentiableTypeConformanceContext& diffTypeContext, + IRInst* addr) { if (!addr) return false; @@ -169,41 +175,43 @@ public: switch (addr->getOp()) { case kIROp_Var: - case kIROp_Param: - return isDifferentiableType(diffTypeContext, addr->getDataType()); + case kIROp_Param: return isDifferentiableType(diffTypeContext, addr->getDataType()); case kIROp_FieldAddress: if (!as<IRFieldAddress>(addr)->getField() || as<IRFieldAddress>(addr) - ->getField() - ->findDecoration<IRDerivativeMemberDecoration>() == nullptr) + ->getField() + ->findDecoration<IRDerivativeMemberDecoration>() == nullptr) return false; addr = as<IRFieldAddress>(addr)->getBase(); break; case kIROp_GetElementPtr: - if (!isDifferentiableType(diffTypeContext, as<IRGetElementPtr>(addr)->getBase()->getDataType())) + if (!isDifferentiableType( + diffTypeContext, + as<IRGetElementPtr>(addr)->getBase()->getDataType())) return false; addr = as<IRGetElementPtr>(addr)->getBase(); break; - default: - return false; + default: return false; } } return false; } - bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst) + bool instHasNonTrivialDerivative( + DifferentiableTypeConformanceContext& diffTypeContext, + IRInst* inst) { switch (inst->getOp()) { - case kIROp_DetachDerivative: - return false; + case kIROp_DetachDerivative: return false; case kIROp_Call: - { - auto call = as<IRCall>(inst); - return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); - } - default: - return isDifferentiableType(diffTypeContext, inst->getDataType()); + { + auto call = as<IRCall>(inst); + return isDifferentiableFunc( + call->getCallee(), + CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); + } + default: return isDifferentiableType(diffTypeContext, inst->getDataType()); } } @@ -246,7 +254,6 @@ public: sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc); return; } - } } } @@ -267,12 +274,11 @@ public: // data flow analysis. // `produceDiffSet` represents a set of insts that can provide a diff. This is conservative // on the positive side: a float literal is considered to be able to provide a diff. - // `carryNonTrivialDiffSet` represents a set of insts that may carry a non-zero diff. This is - // conservative on the negative side: if the inst does not provide a diff, or if we can prove the diff - // is zero, we exclude the inst from the set. This makes `carryNonTrivialDiffSet` a strict subset of - // `produceDiffSet`. - // `expectDiffSet` is a set of insts that expects their operands to produce a diff. It is an error - // if they don't. + // `carryNonTrivialDiffSet` represents a set of insts that may carry a non-zero diff. This + // is conservative on the negative side: if the inst does not provide a diff, or if we can + // prove the diff is zero, we exclude the inst from the set. This makes + // `carryNonTrivialDiffSet` a strict subset of `produceDiffSet`. `expectDiffSet` is a set of + // insts that expects their operands to produce a diff. It is an error if they don't. InstHashSet produceDiffSet(funcInst->getModule()); InstHashSet expectDiffSet(funcInst->getModule()); InstHashSet carryNonTrivialDiffSet(funcInst->getModule()); @@ -302,15 +308,17 @@ public: { switch (inst->getOp()) { - case kIROp_FloatLit: - return true; + case kIROp_FloatLit: return true; case kIROp_Call: return shouldTreatCallAsDifferentiable(inst) || - isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); + isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && + isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: - // We don't have more knowledge on whether diff is available at the destination address. - // Just assume it is producing diff if the dest address can hold a derivative. - //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. + // We don't have more knowledge on whether diff is available at the destination + // address. Just assume it is producing diff if the dest address can hold a + // derivative. + // TODO: propagate the info if this is a load of a temporary variable intended + // to receive result from an `out` parameter. return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr()); default: // default case is to assume the inst produces a diff value if any @@ -332,17 +340,18 @@ public: { switch (inst->getOp()) { - case kIROp_DetachDerivative: - return false; + case kIROp_DetachDerivative: return false; case kIROp_Call: if (shouldTreatCallAsDifferentiable(inst)) return false; return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: - // We don't have more knowledge on whether diff is available at the destination address. - // Just assume it is producing diff if the dest address can hold a derivative. - //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. + // We don't have more knowledge on whether diff is available at the destination + // address. Just assume it is producing diff if the dest address can hold a + // derivative. + // TODO: propagate the info if this is a load of a temporary variable intended + // to receive result from an `out` parameter. return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr()); default: // default case is to assume the inst produces a diff value if any @@ -387,7 +396,8 @@ public: { for (auto p : block->getPredecessors()) { - // A Phi Node is producing diff if any of its candidate values are producing diff. + // A Phi Node is producing diff if any of its candidate values are + // producing diff. if (auto branch = as<IRUnconditionalBranch>(p->getTerminator())) { if (branch->getArgCount() > paramIndex) @@ -418,15 +428,17 @@ public: } break; case kIROp_Store: - { - auto storeInst = as<IRStore>(inst); - if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) && - isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType())) { - addToExpectDiffWorkList(storeInst->getVal()); + auto storeInst = as<IRStore>(inst); + if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) && + isDifferentiableType( + diffTypeContext, + as<IRStore>(inst)->getPtr()->getDataType())) + { + addToExpectDiffWorkList(storeInst->getVal()); + } } - } - break; + break; case kIROp_Return: if (auto returnVal = as<IRReturn>(inst)->getVal()) { @@ -437,8 +449,7 @@ public: } } break; - default: - break; + default: break; } } } @@ -456,64 +467,69 @@ public: // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. - if (isDifferentiableType(diffTypeContext, inst->getFullType()) && + if (isDifferentiableType(diffTypeContext, inst->getFullType()) && !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) { sink->diagnose( inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getResolvedInstForDecorations(call->getCallee()), - requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" + : "backward"); } } } switch (inst->getOp()) { case kIROp_Param: - { - auto block = as<IRBlock>(inst->getParent()); - if (block != funcInst->getFirstBlock()) { - auto paramIndex = getParamIndexInBlock(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)); - if (paramIndex != -1) + auto block = as<IRBlock>(inst->getParent()); + if (block != funcInst->getFirstBlock()) { - for (auto p : block->getPredecessors()) + auto paramIndex = getParamIndexInBlock( + as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)); + if (paramIndex != -1) { - // A Phi Node is producing diff if any of its candidate values are producing diff. - if (auto branch = as<IRUnconditionalBranch>(p->getTerminator())) + for (auto p : block->getPredecessors()) { - if (branch->getArgCount() > (UInt)paramIndex) + // A Phi Node is producing diff if any of its candidate values are + // producing diff. + if (auto branch = as<IRUnconditionalBranch>(p->getTerminator())) { - auto arg = branch->getArg(paramIndex); - addToExpectDiffWorkList(arg); + if (branch->getArgCount() > (UInt)paramIndex) + { + auto arg = branch->getArg(paramIndex); + addToExpectDiffWorkList(arg); + } } } } } + break; } - break; - } case kIROp_Call: - { - auto callInst = as<IRCall>(inst); - if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>()) - continue; - auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); - if (!calleeFuncType) continue; - if (calleeFuncType->getParamCount() != callInst->getArgCount()) - continue; - for (UInt a = 0; a < callInst->getArgCount(); a++) { - auto arg = callInst->getArg(a); - auto paramType = calleeFuncType->getParamType(a); - if (!isDifferentiableType(diffTypeContext, paramType)) + auto callInst = as<IRCall>(inst); + if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>()) + continue; + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) + continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) continue; - addToExpectDiffWorkList(arg); + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + addToExpectDiffWorkList(arg); + } + break; } - break; - } default: - // Default behavior is to request all differentiable operands to provide differential. + // Default behavior is to request all differentiable operands to provide + // differential. for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++) { auto operand = inst->getOperand(opIndex); @@ -542,7 +558,8 @@ public: } if (!hasBackEdge) continue; - if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>()) + if (loop->findDecoration<IRLoopMaxItersDecoration>() || + loop->findDecoration<IRForceUnrollDecoration>()) { // We are good. } @@ -552,9 +569,9 @@ public: } } - // Make sure all stores of differentiable values are into addresses that can hold derivatives. - // If we are assigning a value to a non-differentiable location, we need to make sure - // that value doesn't carray a non-zero diff. + // Make sure all stores of differentiable values are into addresses that can hold + // derivatives. If we are assigning a value to a non-differentiable location, we need to + // make sure that value doesn't carray a non-zero diff. for (auto block : funcInst->getBlocks()) { for (auto inst : block->getChildren()) @@ -564,7 +581,9 @@ public: if (carryNonTrivialDiffSet.contains(storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { - sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + sink->diagnose( + storeInst->sourceLoc, + Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); } } else if (auto callInst = as<IRCall>(inst)) @@ -586,7 +605,10 @@ public: { if (!canAddressHoldDerivative(diffTypeContext, arg)) { - sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg); + sink->diagnose( + arg->sourceLoc, + Diagnostics:: + lossOfDerivativeUsingNonDifferentiableLocationAsOutArg); } } } @@ -632,7 +654,8 @@ public: { if (auto genericInst = as<IRGeneric>(inst)) { - if (auto innerFunc = as<IRGlobalValueWithCode>(findInnerMostGenericReturnVal(genericInst))) + if (auto innerFunc = + as<IRGlobalValueWithCode>(findInnerMostGenericReturnVal(genericInst))) processFunc(innerFunc); } else if (auto funcInst = as<IRGlobalValueWithCode>(inst)) |
