summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp421
1 files changed, 421 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
new file mode 100644
index 000000000..44c6324e3
--- /dev/null
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -0,0 +1,421 @@
+#include "slang-ir-check-differentiability.h"
+
+#include "slang-ir-diff-jvp.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+struct CheckDifferentiabilityPassContext : public InstPassBase
+{
+public:
+ DiagnosticSink* sink;
+ AutoDiffSharedContext sharedContext;
+
+ HashSet<IRInst*> differentiableFunctions;
+
+ CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
+ : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
+ {}
+
+ IRInst* getSpecializedVal(IRInst* inst)
+ {
+ int loopLimit = 1024;
+ while (inst && inst->getOp() == kIROp_Specialize)
+ {
+ inst = as<IRSpecialize>(inst)->getBase();
+ loopLimit--;
+ if (loopLimit == 0)
+ return inst;
+ }
+ return inst;
+ }
+
+ IRInst* getLeafFunc(IRInst* func)
+ {
+ func = getSpecializedVal(func);
+ if (!func)
+ return nullptr;
+ if (auto genericFunc = as<IRGeneric>(func))
+ return findInnerMostGenericReturnVal(genericFunc);
+ return func;
+ }
+
+ bool _isFuncMarkedForAutoDiff(IRInst* func)
+ {
+ func = getLeafFunc(func);
+ if (!func)
+ return false;
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
+
+ bool _isDifferentiableFuncImpl(IRInst* func)
+ {
+ func = getLeafFunc(func);
+ if (!func)
+ return false;
+
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool isDifferentiableFunc(IRInst* func)
+ {
+ switch (func->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return true;
+ default:
+ break;
+ }
+
+ func = getSpecializedVal(func);
+ if (!func)
+ return false;
+
+ if (differentiableFunctions.Contains(func))
+ return true;
+
+ for (; func; func = func->parent)
+ {
+ if (as<IRGeneric>(func))
+ {
+ return differentiableFunctions.Contains(func);
+ }
+ }
+ return false;
+ }
+
+ bool isBackwardDifferentiableFunc(IRInst* func)
+ {
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+ {
+ HashSet<IRInst*> processedSet;
+ while (auto ptrType = as<IRPtrTypeBase>(typeInst))
+ {
+ typeInst = ptrType->getValueType();
+ if (!processedSet.Add(typeInst))
+ return false;
+ }
+ if (!typeInst)
+ return false;
+ switch (typeInst->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_DifferentialPairType:
+ return true;
+ default:
+ break;
+ }
+ if (context.lookUpConformanceForType(typeInst))
+ return true;
+ // Look for equivalent types.
+ for (auto type : context.differentiableWitnessDictionary)
+ {
+ if (isTypeEqual(type.Key, (IRType*)typeInst))
+ {
+ context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ int getParamIndexInBlock(IRParam* paramInst)
+ {
+ auto block = as<IRBlock>(paramInst->getParent());
+ if (!block)
+ return -1;
+ int paramIndex = 0;
+ for (auto param : block->getParams())
+ {
+ if (param == paramInst)
+ return paramIndex;
+ paramIndex++;
+ }
+ return -1;
+ }
+
+ bool isInstInFunc(IRInst* inst, IRInst* func)
+ {
+ while (inst)
+ {
+ if (inst == func)
+ return true;
+ inst = inst->parent;
+ }
+ return false;
+ }
+ void processFunc(IRGlobalValueWithCode* funcInst)
+ {
+ if (!_isFuncMarkedForAutoDiff(funcInst))
+ return;
+ if (!funcInst->getFirstBlock())
+ return;
+
+ DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
+ diffTypeContext.setFunc(funcInst);
+
+ HashSet<IRInst*> produceDiffSet;
+ HashSet<IRInst*> expectDiffSet;
+ int differentiableInputs = 0;
+ int differentiableOutputs = 0;
+ for (auto param : funcInst->getFirstBlock()->getParams())
+ {
+ if (isDifferentiableType(diffTypeContext, param->getFullType()))
+ {
+ if (as<IROutTypeBase>(param->getFullType()))
+ differentiableOutputs++;
+ if (!as<IROutType>(param->getFullType()))
+ differentiableInputs++;
+ produceDiffSet.Add(param);
+ }
+ }
+ if (auto funcType = as<IRFuncType>(funcInst->getDataType()))
+ {
+ if (isDifferentiableType(diffTypeContext, funcType->getResultType()))
+ differentiableOutputs++;
+ }
+
+ if (differentiableOutputs == 0)
+ sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveOutput);
+ if (differentiableInputs == 0)
+ sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput);
+
+ auto isInstProducingDiff = [&](IRInst* inst) -> bool
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_FloatLit:
+ return true;
+ case kIROp_Call:
+ return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
+ case kIROp_Load:
+ // We don't have more knowledge on whether diff is available at the destination address.
+ // Just assume it is producing diff.
+ //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter.
+ return isDifferentiableType(diffTypeContext, inst->getDataType());
+ default:
+ // default case is to assume the inst produces a diff value if any
+ // of its operands produces a diff value.
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (produceDiffSet.Contains(inst->getOperand(i)))
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+ };
+
+ List<IRInst*> expectDiffInstWorkList;
+ OrderedHashSet<IRInst*> expectDiffInstWorkListSet;
+ auto addToExpectDiffWorkList = [&](IRInst* inst)
+ {
+ if (isInstInFunc(inst, funcInst))
+ {
+ if (expectDiffInstWorkListSet.Add(inst))
+ expectDiffInstWorkList.add(inst);
+ }
+ };
+ // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`.
+ Index lastProduceDiffCount = 0;
+ do
+ {
+ lastProduceDiffCount = produceDiffSet.Count();
+ for (auto block : funcInst->getBlocks())
+ {
+ if (block != funcInst->getFirstBlock())
+ {
+ UInt paramIndex = 0;
+ for (auto param : block->getParams())
+ {
+ for (auto p : block->getPredecessors())
+ {
+ // A Phi Node is producing diff if any of its candidate values are producing diff.
+ if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
+ {
+ if (branch->getArgCount() > paramIndex)
+ {
+ auto arg = branch->getArg(paramIndex);
+ if (produceDiffSet.Contains(arg))
+ {
+ produceDiffSet.Add(param);
+ break;
+ }
+ }
+ }
+ }
+ paramIndex++;
+ }
+ }
+ for (auto inst : block->getChildren())
+ {
+ if (isInstProducingDiff(inst))
+ produceDiffSet.Add(inst);
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ if (isDifferentiableFunc(as<IRCall>(inst)->getCallee()))
+ {
+ addToExpectDiffWorkList(inst);
+ }
+ break;
+ case kIROp_Store:
+ {
+ auto storeInst = as<IRStore>(inst);
+ if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
+ {
+ addToExpectDiffWorkList(storeInst->getVal());
+ }
+ }
+ break;
+ case kIROp_Return:
+ if (auto returnVal = as<IRReturn>(inst)->getVal())
+ {
+ if (isDifferentiableType(diffTypeContext, returnVal->getDataType()))
+ {
+ addToExpectDiffWorkList(inst);
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ } while (produceDiffSet.Count() != lastProduceDiffCount);
+
+ // Reverse propagate `expectDiffSet`.
+ for (int i = 0; i < expectDiffInstWorkList.getCount(); i++)
+ {
+ auto inst = expectDiffInstWorkList[i];
+ // Is inst in produceDiffSet?
+ if (!produceDiffSet.Contains(inst))
+ {
+ if (auto call = as<IRCall>(inst))
+ {
+ sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee()));
+ }
+ }
+ switch (inst->getOp())
+ {
+ case kIROp_Param:
+ {
+ auto block = as<IRBlock>(inst->getParent());
+ if (block != funcInst->getFirstBlock())
+ {
+ auto paramIndex = getParamIndexInBlock(as<IRParam>(inst));
+ if (paramIndex != -1)
+ {
+ for (auto p : block->getPredecessors())
+ {
+ // A Phi Node is producing diff if any of its candidate values are producing diff.
+ if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
+ {
+ if (branch->getArgCount() > (UInt)paramIndex)
+ {
+ auto arg = branch->getArg(paramIndex);
+ addToExpectDiffWorkList(arg);
+ }
+ }
+ }
+ }
+ }
+ break;
+ }
+ default:
+ // Default behavior is to request all differentiable operands to provide differential.
+ for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++)
+ {
+ auto operand = inst->getOperand(opIndex);
+ if (isDifferentiableType(diffTypeContext, operand->getFullType()))
+ {
+ addToExpectDiffWorkList(operand);
+ }
+ }
+ }
+ }
+ }
+
+ void processModule()
+ {
+ // Collect set of differentiable functions.
+ HashSet<UnownedStringSlice> differentiableSymbolNames;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (_isDifferentiableFuncImpl(inst))
+ {
+ if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+ differentiableSymbolNames.Add(linkageDecor->getMangledName());
+ differentiableFunctions.Add(inst);
+ }
+ }
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+ {
+ if (differentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ differentiableFunctions.Add(inst);
+ }
+ }
+
+ if (!sharedContext.isInterfaceAvailable)
+ return;
+
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ if (auto innerFunc = as<IRGlobalValueWithCode>(findGenericReturnVal(genericInst)))
+ processFunc(innerFunc);
+ }
+ else if (auto funcInst = as<IRGlobalValueWithCode>(inst))
+ {
+ processFunc(funcInst);
+ }
+ }
+ }
+};
+
+void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink)
+{
+ CheckDifferentiabilityPassContext context(module, sink);
+ context.processModule();
+}
+
+} // namespace Slang