diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-20 14:42:50 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-20 14:42:50 -0800 |
| commit | 47715e625337d489f3c0131bbc2b849378b48a5a (patch) | |
| tree | bc737c8f03ef537b2ac39860bbb922c7600edc43 | |
| parent | 8b05df4187117d61491f2fdbeb7d744146ad73f7 (diff) | |
Miscellaneous backward autodiff fixes. (#2665)
* Fix differentiable type registration
* Fix use of non-differentiable return value in a differentiable func.
* Fix use of primal inst that does not dominate the diff block.
* Fix primal inst hoisting, and add missing type legalization logic.
* Make `detach` defined on all differentiable T.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
24 files changed, 493 insertions, 144 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a60a77cc3..859b8a488 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -319,7 +319,13 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma // Detach and set derivatives to zero -__generic<T : __BuiltinFloatingPointType> +__generic<T : IDifferentiable> +T detach(T x) +{ + return x; +} + +__generic<T : IDifferentiable> [ForwardDerivativeOf(detach)] DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) { @@ -329,27 +335,13 @@ DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) ); } -__generic<T : __BuiltinFloatingPointType, let N : int> -[ForwardDerivativeOf(detach)] -DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx) -{ - VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); -} - -__generic<T : __BuiltinFloatingPointType> +__generic<T : IDifferentiable> [BackwardDerivativeOf(detach)] void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut) { dpx = diffPair(dpx.p, T.dzero()); } -__generic<T : __BuiltinFloatingPointType, let N : int> -[BackwardDerivativeOf(detach)] -void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) -{ - dpx = diffPair(dpx.p, vector<T, N>.dzero()); -} - // Natural Exponent __generic<T : __BuiltinFloatingPointType> diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1d2b327d2..7e75d06b3 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -770,26 +770,6 @@ struct TriangleStream // Try to terminate the current draw or dispatch call (HLSL SM 4.0) void abort(); -// Detach and set derivatives to zero - -__generic<T : __BuiltinFloatingPointType> -T detach(T x) -{ - return x; -} - -__generic<T : __BuiltinFloatingPointType, let N : int> -vector<T, N> detach(vector<T, N> x) -{ - return x; -} - -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -matrix<T, N, M> detach(matrix<T, N, M> x) -{ - return x; -} - // Absolute value (HLSL SM 1.0) __generic<T : __BuiltinIntegerType> diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 75a3c2ff1..3567e2593 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1024,6 +1024,23 @@ namespace Slang }); } } + for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer) + { + if (auto genSubst = as<GenericSubstitution>(subst)) + { + for (auto arg : genSubst->getArgs()) + { + if (auto typeArg = as<Type>(arg)) + { + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet); + } + } + } + else if (auto thisSubst = as<ThisTypeSubstitution>(subst)) + { + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet); + } + } return; } } diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 519ca91ff..bc89dc94e 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -585,6 +585,7 @@ namespace Slang void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) { + SLANG_UNUSED(stmt); if (getParentDifferentiableAttribute()) { if (!getParentFunc()) @@ -601,16 +602,6 @@ namespace Slang return; if (getParentFunc()->findModifier<BackwardDerivativeAttribute>()) return; - - // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute, - // or a `[ForceUnroll]` attribet on loops. - if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>()) - { - } - else - { - getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); - } } } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 26d84720f..564c33268 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -488,7 +488,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (!diffReturnType) { - diffReturnType = argBuilder.getVoidType(); + diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); } auto callInst = argBuilder.emitCallInst( @@ -501,7 +501,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig *builder = afterBuilder; - if (diffReturnType->getOp() != kIROp_VoidType) + if (diffReturnType->getOp() == kIROp_DifferentialPairType) { IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); @@ -510,8 +510,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } else { - // Return the inst itself if the return value is void. - // This is fine since these values should never actually be used anywhere. + // Return the inst itself if the return value is non-differentiable. + // This is fine since these values should only be used by non-differentiable code. // return InstPair(callInst, callInst); } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ef6178976..55c0ee46d 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -13,6 +13,7 @@ #include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-init-local-var.h" #include "slang-ir-redundancy-removal.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -674,6 +675,106 @@ namespace Slang return fwdDiffFunc; } + void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc) + { + RefPtr<IRDominatorTree> domTree = computeDominatorTree(diffPropFunc); + auto firstBlock = diffPropFunc->getFirstBlock(); + if (!firstBlock) + return; + Dictionary<IRInst*, IRVar*> instVars; + Dictionary<IRBlock*, IRCloneEnv> cloneEnvs; + auto storeInstAsLocalVar = [&](IRInst* inst) + { + IRVar* var = nullptr; + if (instVars.TryGetValue(inst, var)) + return var; + IRBuilder builder(diffPropFunc); + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + var = builder.emitVar(inst->getDataType()); + builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType())); + + setInsertAfterOrdinaryInst(&builder, inst); + builder.emitStore(var, inst); + instVars[inst] = var; + return var; + }; + + IRBuilder builder(diffPropFunc); + List<IRInst*> workList; + for (auto block : diffPropFunc->getBlocks()) + { + if (!block->findDecoration<IRDifferentialInstDecoration>()) + continue; + cloneEnvs[block] = IRCloneEnv(); + for (auto inst : block->getChildren()) + { + workList.add(inst); + } + } + + for (Index i = 0; i < workList.getCount(); i++) + { + auto inst = workList[i]; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto operand = inst->getOperand(j); + if (operand->getOp() == kIROp_Block) + continue; + auto operandParent = inst->getOperand(j)->getParent(); + if (!operandParent) + continue; + if (operandParent->parent != diffPropFunc) + continue; + if (domTree->dominates(operandParent, inst->parent)) + continue; + + // The def site of the operand does not dominate the use. + // We need to insert a local variable to store this var. + + IRInst* operandReplacement = nullptr; + if (canInstBeStored(operand)) + { + auto var = storeInstAsLocalVar(operand); + builder.setInsertBefore(inst); + operandReplacement = builder.emitLoad(var); + } + else if (operand->getOp() == kIROp_Var) + { + // Var can just be hoisted to first block. + operand->insertBefore(firstBlock->getFirstOrdinaryInst()); + } + else + { + // For all other insts, we need to copy it to right before this inst. + // Before actually copying it, check if we have already copied it to + // any blocks that dominates this block. + auto dom = as<IRBlock>(inst->getParent()); + while (dom) + { + auto subCloneEnv = cloneEnvs.TryGetValue(dom); + if (!subCloneEnv) break; + if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement)) + { + break; + } + dom = domTree->getImmediateDominator(dom); + } + // We have not found an existing clone in dominators, so we need to copy it + // to this block. + if (!operandReplacement) + { + auto subCloneEnv = cloneEnvs.TryGetValue(as<IRBlock>(inst->getParent())); + builder.setInsertBefore(inst); + operandReplacement = cloneInst(subCloneEnv, &builder, operand); + workList.add(operandReplacement); + } + } + if (operandReplacement) + builder.replaceOperand(inst->getOperands() + j, operandReplacement); + } + } + } + InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) { SLANG_UNUSED(primalType); @@ -838,6 +939,8 @@ namespace Slang initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getModule(), diffPropagateFunc); + insertVariableForRecomputedPrimalInsts(diffPropagateFunc); + stripTempDecorations(diffPropagateFunc); } ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( @@ -1151,16 +1254,15 @@ namespace Slang while (refUse) { auto nextUse = refUse->nextUse; - switch (refUse->getUser()->getOp()) + // Is this use the dest operand of a store inst? + // If so, replace it with writeRefReplacement, otherwise, refReplacement. + if (refUse->getUser()->getOp() == kIROp_Store && refUse == refUse->getUser()->getOperands()) { - case kIROp_Load: - refUse->set(diffRefReplacement); - break; - case kIROp_Store: + SLANG_RELEASE_ASSERT(diffWriteRefReplacement); refUse->set(diffWriteRefReplacement); - break; - default: - SLANG_RELEASE_ASSERT(!diffWriteRefReplacement); + } + else + { refUse->set(diffRefReplacement); } refUse = nextUse; diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index a638b873c..94bc1ef81 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -114,6 +114,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc); + void insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc); + void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc); InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 614559c9f..4e7539b48 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -137,39 +137,8 @@ struct ExtractPrimalFuncContext return false; } - // Only store allowed types. - if (isScalarIntegerType(inst->getDataType())) - { - } - else if (as<IRResourceTypeBase>(inst->getDataType())) - { - } - else - { - switch (inst->getDataType()->getOp()) - { - case kIROp_StructType: - case kIROp_OptionalType: - case kIROp_TupleType: - case kIROp_ArrayType: - case kIROp_DifferentialPairType: - case kIROp_InterfaceType: - case kIROp_AnyValueType: - case kIROp_ClassType: - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - case kIROp_MatrixType: - case kIROp_BoolType: - case kIROp_Param: - case kIROp_Specialize: - case kIROp_LookupWitness: - break; - default: - return false; - } - } + if (!canInstBeStored(inst)) + return false; // Never store certain opcodes. switch (inst->getOp()) @@ -507,11 +476,19 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( else { // Orindary value. - auto val = builder.emitFieldExtract( - inst->getFullType(), - intermediateVar, - structKeyDecor->getStructKey()); - inst->replaceUsesWith(val); + // We insert a fieldExtract at each use site instead of before `inst`, + // since at this stage of autodiff pass, `inst` does not necessarily + // dominate all the use sites if `inst` is defined in partial branch + // in a primal block. + while (auto iuse = inst->firstUse) + { + builder.setInsertBefore(iuse->getUser()); + auto val = builder.emitFieldExtract( + inst->getFullType(), + intermediateVar, + structKeyDecor->getStructKey()); + iuse->set(val); + } } instsToRemove.add(inst); } @@ -529,8 +506,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { inst->removeAndDeallocate(); } - - stripTempDecorations(func); return primalFunc; } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2d5261b63..944df2c81 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -498,34 +498,6 @@ struct DiffUnzipPass } } - void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) - { - if (as<IRParam>(inst)) - { - SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); - auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); - builder->setInsertAfter(lastParam); - } - else - { - builder->setInsertBefore(inst); - } - } - - void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) - { - if (as<IRParam>(inst)) - { - SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); - auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); - builder->setInsertAfter(lastParam); - } - else - { - builder->setInsertAfter(inst); - } - } - void processIndexedFwdBlock(IRBlock* fwdBlock) { if (!isBlockIndexed(fwdBlock)) @@ -794,6 +766,7 @@ struct DiffUnzipPass auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs); primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); + primalBuilder->markInstAsPrimal(primalVal); SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount()); @@ -1377,6 +1350,11 @@ struct DiffUnzipPass // Remove insts that were split. for (auto inst : splitInsts) { + if (!isDifferentiableType(diffTypeContext, inst->getDataType())) + { + inst->replaceUsesWith(lookupPrimalInst(inst)); + } + // Consistency check. for (auto use = inst->firstUse; use; use = use->nextUse) { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 97cdb644e..b630b798d 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -563,6 +563,30 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* return false; } +bool canInstBeStored(IRInst* inst) +{ + if (as<IRBasicType>(inst->getDataType())) + return true; + + switch (inst->getDataType()->getOp()) + { + case kIROp_StructType: + case kIROp_OptionalType: + case kIROp_TupleType: + case kIROp_ArrayType: + case kIROp_DifferentialPairType: + case kIROp_InterfaceType: + case kIROp_AnyValueType: + case kIROp_ClassType: + case kIROp_FloatType: + case kIROp_VectorType: + case kIROp_MatrixType: + return true; + default: + return false; + } +} + struct AutoDiffPass : public InstPassBase { DiagnosticSink* getSink() diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index f57fb2974..fa01d50ae 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -303,6 +303,8 @@ bool isBackwardDifferentiableFunc(IRInst* func); bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); +bool canInstBeStored(IRInst* inst); + inline bool isRelevantDifferentialPair(IRType* type) { if (as<IRDifferentialPairType>(type)) diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index c750b2d3d..1ee94e67e 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -353,6 +353,17 @@ public: auto loop = as<IRLoop>(block->getTerminator()); if (!loop) continue; + bool hasBackEdge = false; + for (auto use = loop->getTargetBlock()->firstUse; use; use = use->nextUse) + { + if (use->getUser() != loop) + { + hasBackEdge = true; + break; + } + } + if (!hasBackEdge) + continue; if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>()) { // We are good. diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 96ffb9bb2..7660c9526 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -701,6 +701,61 @@ static LegalVal legalizeRetVal( return LegalVal(); } +static void _addVal(List<IRInst*>& rs, const LegalVal& legalVal) +{ + switch (legalVal.flavor) + { + case LegalVal::Flavor::simple: + rs.add(legalVal.getSimple()); + break; + case LegalVal::Flavor::tuple: + for (auto element : legalVal.getTuple()->elements) + _addVal(rs, element.val); + break; + case LegalVal::Flavor::pair: + _addVal(rs, legalVal.getPair()->ordinaryVal); + _addVal(rs, legalVal.getPair()->specialVal); + break; + case LegalVal::Flavor::none: + break; + default: + SLANG_UNEXPECTED("unhandled legalized val flavor"); + } +} + +static LegalVal legalizeUnconditionalBranch( + IRTypeLegalizationContext* context, + ArrayView<LegalVal> args, + IRUnconditionalBranch* branchInst) +{ + List<IRInst*> newArgs; + for (auto arg : args) + { + switch (arg.flavor) + { + case LegalVal::Flavor::none: + break; + case LegalVal::Flavor::simple: + newArgs.add(arg.getSimple()); + break; + case LegalVal::Flavor::pair: + _addVal(newArgs, arg.getPair()->ordinaryVal); + _addVal(newArgs, arg.getPair()->specialVal); + break; + case LegalVal::Flavor::tuple: + for (auto element : arg.getTuple()->elements) + { + _addVal(newArgs, element.val); + } + break; + default: + SLANG_UNIMPLEMENTED_X("Unknown legalized val flavor."); + } + } + context->builder->emitBranch(branchInst->getTargetBlock(), newArgs.getCount() - 1, newArgs.getBuffer() + 1); + return LegalVal(); +} + static LegalVal legalizeLoad( IRTypeLegalizationContext* context, LegalVal legalPtrVal) @@ -1610,11 +1665,69 @@ static LegalVal legalizeMakeStruct( } } +static LegalVal legalizeDefaultConstruct( + IRTypeLegalizationContext* context, + LegalType legalType) +{ + auto builder = context->builder; + + switch (legalType.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + + case LegalType::Flavor::simple: + { + return LegalVal::simple( + builder->emitDefaultConstruct(legalType.getSimple())); + } + + case LegalType::Flavor::pair: + { + auto pairType = legalType.getPair(); + auto pairInfo = pairType->pairInfo; + LegalType ordinaryType = pairType->ordinaryType; + LegalType specialType = pairType->specialType; + + LegalVal ordinaryVal = legalizeDefaultConstruct( + context, + ordinaryType); + + LegalVal specialVal = legalizeDefaultConstruct( + context, + specialType); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalType::Flavor::tuple: + { + auto tupleType = legalType.getTuple(); + + RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal(); + for (auto typeElem : tupleType->elements) + { + auto elemKey = typeElem.key; + TuplePseudoVal::Element resElem; + resElem.key = elemKey; + resElem.val = legalizeDefaultConstruct(context, typeElem.type); + resTupleInfo->elements.add(resElem); + } + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + static LegalVal legalizeInst( IRTypeLegalizationContext* context, IRInst* inst, LegalType type, - LegalVal const* args) + ArrayView<LegalVal> args) { switch (inst->getOp()) { @@ -1647,8 +1760,14 @@ static LegalVal legalizeInst( return legalizeMakeStruct( context, type, - args, + args.getBuffer(), inst->getOperandCount()); + case kIROp_DefaultConstruct: + return legalizeDefaultConstruct( + context, + type); + case kIROp_unconditionalBranch: + return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst); case kIROp_undefined: return LegalVal(); case kIROp_GpuForeach: @@ -1896,7 +2015,7 @@ static LegalVal legalizeInst( context, inst, legalType, - legalArgs.getBuffer()); + legalArgs.getArrayView()); if (legalVal.flavor == LegalVal::Flavor::simple) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f3c4c2c82..3db036a8d 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -224,7 +224,7 @@ String dumpIRToString(IRInst* root) StringBuilder sb; StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); IRDumpOptions options = {}; -#if 0 +#if 1 options.flags = IRDumpOptions::Flag::DumpDebugIds; #endif dumpIR(root, options, nullptr, &writer); @@ -487,6 +487,34 @@ IROp getSwapSideComparisonOp(IROp op) } } +void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) +{ + if (as<IRParam>(inst)) + { + SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); + auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); + } + else + { + builder->setInsertBefore(inst); + } +} + +void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) +{ + if (as<IRParam>(inst)) + { + SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); + auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); + } + else + { + builder->setInsertAfter(inst); + } +} + bool isPureFunctionalCall(IRCall* call) { auto callee = getResolvedInstForDecorations(call->getCallee()); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 0fb26f791..8a12ab895 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -173,6 +173,12 @@ IRInst* getUndefInst(IRBuilder builder, IRModule* module); // The the equivalent op of (a op b) in (b op' a). For example, a > b is equivalent to b < a. So (<) ==> (>). IROp getSwapSideComparisonOp(IROp op); +// Set IRBuilder to insert before `inst`. If `inst` is a param, it will insert after the last param. +void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst); + +// Set IRBuilder to insert after `inst`. If `inst` is a param, it will insert after the last param. +void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst); + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index dc61a45af..accefc0c9 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5662,9 +5662,9 @@ namespace Slang #if SLANG_ENABLE_IR_BREAK_ALLOC if (context->options.flags & IRDumpOptions::Flag::DumpDebugIds) { - dump(context, "[#"); + dump(context, "{"); dump(context, String(inst->_debugUID)); - dump(context, "]"); + dump(context, "}\t"); } #else SLANG_UNUSED(context); @@ -5691,7 +5691,6 @@ namespace Slang { dump(context, "_"); } - dumpDebugID(context, inst); } static void dumpEncodeString( @@ -5819,6 +5818,7 @@ namespace Slang IRBlock* block) { context->indent--; + dumpDebugID(context, block); dump(context, "block "); dumpID(context, block); @@ -6050,7 +6050,6 @@ namespace Slang } dump(context, opInfo.name); - dumpDebugID(context, inst); dumpInstOperandList(context, inst); } @@ -6068,6 +6067,8 @@ namespace Slang dumpIRDecorations(context, inst); + dumpDebugID(context, inst); + // There are several ops we want to special-case here, // so that they will be more pleasant to look at. // @@ -6204,7 +6205,10 @@ namespace Slang context.options = options; context.sourceManager = sourceManager; - dumpInst(&context, globalVal); + if (globalVal->getOp() == kIROp_Module) + dumpIRModule(&context, globalVal->getModule()); + else + dumpInst(&context, globalVal); writer->write(sb.getBuffer(), sb.getLength()); writer->flush(); diff --git a/tests/autodiff/bool-return-control-flow.slang b/tests/autodiff/bool-return-control-flow.slang new file mode 100644 index 000000000..9dd398a89 --- /dev/null +++ b/tests/autodiff/bool-return-control-flow.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +bool conditionFunc(no_diff float a, inout float x) +{ + x = x * a; + return x > 100.f; +} + +[BackwardDifferentiable] +float outerFunc(no_diff float a, float x) +{ + if (conditionFunc(a, x)) + return x; + else + return -x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float a = 10.0; + DifferentialPair<float> dpx = DifferentialPair<float>(4.f, 1.f); + __bwd_diff(outerFunc)(a, dpx, 1.0); + + outputBuffer[0] = dpx.d; +}
\ No newline at end of file diff --git a/tests/autodiff/bool-return-control-flow.slang.expected.txt b/tests/autodiff/bool-return-control-flow.slang.expected.txt new file mode 100644 index 000000000..e2789a12c --- /dev/null +++ b/tests/autodiff/bool-return-control-flow.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +-10.000000 +0.000000 +0.000000 +0.000000
\ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-2.slang b/tests/autodiff/reverse-control-flow-2.slang new file mode 100644 index 000000000..cde707b4d --- /dev/null +++ b/tests/autodiff/reverse-control-flow-2.slang @@ -0,0 +1,75 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +bool doWork(float x, out float y) +{ + bool retVal = false; + y = 0; + for (;;) + { + if (x == 0.0) + break; + + bool exited = (x == 1.0); + + y += x; + + if (!exited) + { + if (x < 1.0) + { + float b = x * 2.0f; + y += b; + exited = true; + } + } + retVal = true; + break; + } + return retVal; +} + +[BackwardDifferentiable] +bool doWork2(float x, out float y) +{ + y = 0; + + if (x == 0.0) return false; + + [ForceUnroll] + for (int i = 0; i < 2; ++i) + { + if (x > 0.0) + { + y += x; + + if (x == 1.0) break; + + y += x; + } + else + { + y += x; + } + } + return true; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + var dpx = diffPair(0.5f, 1.0f); + __bwd_diff(doWork)(dpx, 1.0f); + outputBuffer[0] = dpx.d; + } + { + var dpx = diffPair(0.5f, 0.0f); + __bwd_diff(doWork2)(dpx, 1.0); + outputBuffer[1] = dpx.d; + } +} diff --git a/tests/autodiff/reverse-control-flow-2.slang.expected.txt b/tests/autodiff/reverse-control-flow-2.slang.expected.txt new file mode 100644 index 000000000..31023ed32 --- /dev/null +++ b/tests/autodiff/reverse-control-flow-2.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +3.000000 +4.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/diagnostics/autodiff-data-flow-2.slang b/tests/diagnostics/autodiff-data-flow-2.slang index aa923c5d6..3148c6a41 100644 --- a/tests/diagnostics/autodiff-data-flow-2.slang +++ b/tests/diagnostics/autodiff-data-flow-2.slang @@ -24,5 +24,11 @@ float h(float x) float val = 0; // no diagnostic by clarifying intention. val = no_diff(f(x + 1)); + + // error: dynamic loop without [MaxIters] or [ForceUnroll] + for (int i = 0; i < (int)x; i++) + { + } + return val; } diff --git a/tests/diagnostics/autodiff-data-flow-2.slang.expected b/tests/diagnostics/autodiff-data-flow-2.slang.expected index 9026c0748..725a27c9c 100644 --- a/tests/diagnostics/autodiff-data-flow-2.slang.expected +++ b/tests/diagnostics/autodiff-data-flow-2.slang.expected @@ -1,8 +1,11 @@ result code = -1 standard error = { -tests/diagnostics/autodiff-data-flow-2.slang(18): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `f`, use 'no_diff' to clarify intention. +tests/diagnostics/autodiff-data-flow-2.slang(17): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `f`, use 'no_diff' to clarify intention. float val = f(x + 1); // Error: f must also be backward-differentiable ^ +tests/diagnostics/autodiff-data-flow-2.slang(29): error 30510: loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute. + for (int i = 0; i < (int)x; i++) + ^~~ } standard output = { } diff --git a/tests/diagnostics/autodiff.slang b/tests/diagnostics/autodiff.slang index 935ef07cb..7905b48b6 100644 --- a/tests/diagnostics/autodiff.slang +++ b/tests/diagnostics/autodiff.slang @@ -12,11 +12,6 @@ float f(float x) if (x > 5) val = x + 1; - // warning: dynamic loop without [MaxIters] or [ForceUnroll] - for (int i = 0; i < (int)x; i++) - { - } - [MaxIters(2)] for (int i = 0; i < (int)x; i++) // OK { diff --git a/tests/diagnostics/autodiff.slang.expected b/tests/diagnostics/autodiff.slang.expected index 952503d1c..d075b9406 100644 --- a/tests/diagnostics/autodiff.slang.expected +++ b/tests/diagnostics/autodiff.slang.expected @@ -1,8 +1,5 @@ result code = -1 standard error = { -tests/diagnostics/autodiff.slang(16): error 30510: loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute. - for (int i = 0; i < (int)x; i++) - ^~~ tests/diagnostics/autodiff.slang(35): error 38031: 'no_diff' can only be used to decorate a call. float x1 = no_diff x; // invalid use of no_diff here. ^~~~~~~ |
