#include "slang-ir-check-differentiability.h" #include "slang-ir-autodiff.h" #include "slang-ir-inst-pass-base.h" namespace Slang { struct CheckDifferentiabilityPassContext : public InstPassBase { public: DiagnosticSink* sink; AutoDiffSharedContext sharedContext; HashSet differentiableFunctions; CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink) : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst()) {} IRInst* getSpecializedVal(IRInst* inst) { int loopLimit = 1024; while (inst && inst->getOp() == kIROp_Specialize) { inst = as(inst)->getBase(); loopLimit--; if (loopLimit == 0) return inst; } return inst; } IRInst* getLeafFunc(IRInst* func) { func = getSpecializedVal(func); if (!func) return nullptr; if (auto genericFunc = as(func)) return findInnerMostGenericReturnVal(genericFunc); return func; } bool _isFuncMarkedForAutoDiff(IRInst* func) { func = getLeafFunc(func); if (!func) return false; for (auto decorations : func->getDecorations()) { switch (decorations->getOp()) { case kIROp_ForwardDifferentiableDecoration: case kIROp_BackwardDifferentiableDecoration: return true; } } return false; } bool _isDifferentiableFuncImpl(IRInst* func) { func = getLeafFunc(func); if (!func) return false; for (auto decorations : func->getDecorations()) { switch (decorations->getOp()) { case kIROp_ForwardDerivativeDecoration: case kIROp_ForwardDifferentiableDecoration: case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDifferentiableDecoration: return true; } } return false; } bool isDifferentiableFunc(IRInst* func) { switch (func->getOp()) { case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: return true; default: break; } func = getSpecializedVal(func); if (!func) return false; if (differentiableFunctions.Contains(func)) return true; if (func->findDecoration()) return true; if (auto lookupInterfaceMethod = as(func)) { auto wit = lookupInterfaceMethod->getWitnessTable(); if (!wit) return false; auto witType = as(wit->getDataType()); if (!witType) return false; auto interfaceType = witType->getConformanceType(); if (!interfaceType) return false; if (interfaceType->findDecoration()) return true; if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType) return true; auto dictDecor = interfaceType->findDecoration(); if (!dictDecor) return false; for (auto child : dictDecor->getChildren()) { if (auto entry = as(child)) { if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey()) { return true; } } } } for (; func; func = func->parent) { if (as(func)) { return differentiableFunctions.Contains(func); } } return false; } bool isBackwardDifferentiableFunc(IRInst* func) { for (auto decorations : func->getDecorations()) { switch (decorations->getOp()) { case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDifferentiableDecoration: return true; } } return false; } bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) { HashSet processedSet; while (auto ptrType = as(typeInst)) { typeInst = ptrType->getValueType(); if (!processedSet.Add(typeInst)) return false; } if (!typeInst) return false; switch (typeInst->getOp()) { case kIROp_FloatType: case kIROp_DifferentialPairType: return true; default: break; } if (context.lookUpConformanceForType(typeInst)) return true; // Look for equivalent types. for (auto type : context.differentiableWitnessDictionary) { if (isTypeEqual(type.Key, (IRType*)typeInst)) { context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; return true; } } return false; } int getParamIndexInBlock(IRParam* paramInst) { auto block = as(paramInst->getParent()); if (!block) return -1; int paramIndex = 0; for (auto param : block->getParams()) { if (param == paramInst) return paramIndex; paramIndex++; } return -1; } bool isInstInFunc(IRInst* inst, IRInst* func) { while (inst) { if (inst == func) return true; inst = inst->parent; } return false; } void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) return; if (!funcInst->getFirstBlock()) return; DifferentiableTypeConformanceContext diffTypeContext(&sharedContext); diffTypeContext.setFunc(funcInst); HashSet produceDiffSet; HashSet expectDiffSet; int differentiableInputs = 0; int differentiableOutputs = 0; for (auto param : funcInst->getFirstBlock()->getParams()) { if (isDifferentiableType(diffTypeContext, param->getFullType())) { if (as(param->getFullType())) differentiableOutputs++; if (!as(param->getFullType())) differentiableInputs++; produceDiffSet.Add(param); } } if (auto funcType = as(funcInst->getDataType())) { if (isDifferentiableType(diffTypeContext, funcType->getResultType())) differentiableOutputs++; } if (differentiableOutputs == 0) sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveOutput); if (differentiableInputs == 0) sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput); auto isInstProducingDiff = [&](IRInst* inst) -> bool { switch (inst->getOp()) { case kIROp_FloatLit: return true; case kIROp_Call: return inst->findDecoration() || isDifferentiableFunc(as(inst)->getCallee()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. // Just assume it is producing diff. //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. return isDifferentiableType(diffTypeContext, inst->getDataType()); default: // default case is to assume the inst produces a diff value if any // of its operands produces a diff value. for (UInt i = 0; i < inst->getOperandCount(); i++) { if (produceDiffSet.Contains(inst->getOperand(i))) { return true; } } return false; } }; List expectDiffInstWorkList; OrderedHashSet expectDiffInstWorkListSet; auto addToExpectDiffWorkList = [&](IRInst* inst) { if (isInstInFunc(inst, funcInst)) { if (expectDiffInstWorkListSet.Add(inst)) expectDiffInstWorkList.add(inst); } }; // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`. Index lastProduceDiffCount = 0; do { lastProduceDiffCount = produceDiffSet.Count(); for (auto block : funcInst->getBlocks()) { if (block != funcInst->getFirstBlock()) { UInt paramIndex = 0; for (auto param : block->getParams()) { for (auto p : block->getPredecessors()) { // A Phi Node is producing diff if any of its candidate values are producing diff. if (auto branch = as(p->getTerminator())) { if (branch->getArgCount() > paramIndex) { auto arg = branch->getArg(paramIndex); if (produceDiffSet.Contains(arg)) { produceDiffSet.Add(param); break; } } } } paramIndex++; } } for (auto inst : block->getChildren()) { if (isInstProducingDiff(inst)) produceDiffSet.Add(inst); switch (inst->getOp()) { case kIROp_Call: if (isDifferentiableFunc(as(inst)->getCallee())) { addToExpectDiffWorkList(inst); } break; case kIROp_Store: { auto storeInst = as(inst); if (isDifferentiableType(diffTypeContext, as(inst)->getPtr()->getDataType())) { addToExpectDiffWorkList(storeInst->getVal()); } } break; case kIROp_Return: if (auto returnVal = as(inst)->getVal()) { if (isDifferentiableType(diffTypeContext, returnVal->getDataType())) { addToExpectDiffWorkList(inst); } } break; default: break; } } } } while (produceDiffSet.Count() != lastProduceDiffCount); // Reverse propagate `expectDiffSet`. for (int i = 0; i < expectDiffInstWorkList.getCount(); i++) { auto inst = expectDiffInstWorkList[i]; // Is inst in produceDiffSet? if (!produceDiffSet.Contains(inst)) { if (auto call = as(inst)) { sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee())); } } switch (inst->getOp()) { case kIROp_Param: { auto block = as(inst->getParent()); if (block != funcInst->getFirstBlock()) { auto paramIndex = getParamIndexInBlock(as(inst)); if (paramIndex != -1) { for (auto p : block->getPredecessors()) { // A Phi Node is producing diff if any of its candidate values are producing diff. if (auto branch = as(p->getTerminator())) { if (branch->getArgCount() > (UInt)paramIndex) { auto arg = branch->getArg(paramIndex); addToExpectDiffWorkList(arg); } } } } } break; } default: // Default behavior is to request all differentiable operands to provide differential. for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++) { auto operand = inst->getOperand(opIndex); if (isDifferentiableType(diffTypeContext, operand->getFullType())) { addToExpectDiffWorkList(operand); } } } } } void processModule() { // Collect set of differentiable functions. HashSet differentiableSymbolNames; for (auto inst : module->getGlobalInsts()) { if (_isDifferentiableFuncImpl(inst)) { if (auto linkageDecor = inst->findDecoration()) differentiableSymbolNames.Add(linkageDecor->getMangledName()); differentiableFunctions.Add(inst); } } for (auto inst : module->getGlobalInsts()) { if (auto linkageDecor = inst->findDecoration()) { if (differentiableSymbolNames.Contains(linkageDecor->getMangledName())) differentiableFunctions.Add(inst); } } if (!sharedContext.isInterfaceAvailable) return; for (auto inst : module->getGlobalInsts()) { if (auto genericInst = as(inst)) { if (auto innerFunc = as(findGenericReturnVal(genericInst))) processFunc(innerFunc); } else if (auto funcInst = as(inst)) { processFunc(funcInst); } } } }; void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink) { CheckDifferentiabilityPassContext context(module, sink); context.processModule(); } } // namespace Slang