From 47715e625337d489f3c0131bbc2b849378b48a5a Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 20 Feb 2023 14:42:50 -0800 Subject: Miscellaneous backward autodiff fixes. (#2665) * Fix differentiable type registration * Fix use of non-differentiable return value in a differentiable func. * Fix use of primal inst that does not dominate the diff block. * Fix primal inst hoisting, and add missing type legalization logic. * Make `detach` defined on all differentiable T. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 24 ++--- source/slang/hlsl.meta.slang | 20 ---- source/slang/slang-check-expr.cpp | 17 +++ source/slang/slang-check-stmt.cpp | 11 +- source/slang/slang-ir-autodiff-fwd.cpp | 8 +- source/slang/slang-ir-autodiff-rev.cpp | 118 ++++++++++++++++++-- source/slang/slang-ir-autodiff-rev.h | 2 + source/slang/slang-ir-autodiff-unzip.cpp | 55 +++------- source/slang/slang-ir-autodiff-unzip.h | 34 ++---- source/slang/slang-ir-autodiff.cpp | 24 +++++ source/slang/slang-ir-autodiff.h | 2 + source/slang/slang-ir-check-differentiability.cpp | 11 ++ source/slang/slang-ir-legalize-types.cpp | 125 +++++++++++++++++++++- source/slang/slang-ir-util.cpp | 30 +++++- source/slang/slang-ir-util.h | 6 ++ source/slang/slang-ir.cpp | 14 ++- 16 files changed, 366 insertions(+), 135 deletions(-) (limited to 'source') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a60a77cc3..859b8a488 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -319,7 +319,13 @@ void mul(inout DifferentialPair> left, inout DifferentialPair +__generic +T detach(T x) +{ + return x; +} + +__generic [ForwardDerivativeOf(detach)] DifferentialPair __d_detach(DifferentialPair dpx) { @@ -329,27 +335,13 @@ DifferentialPair __d_detach(DifferentialPair dpx) ); } -__generic -[ForwardDerivativeOf(detach)] -DifferentialPair> __d_detach_vector(DifferentialPair> dpx) -{ - VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); -} - -__generic +__generic [BackwardDerivativeOf(detach)] void __d_detach(inout DifferentialPair dpx, T.Differential dOut) { dpx = diffPair(dpx.p, T.dzero()); } -__generic -[BackwardDerivativeOf(detach)] -void __d_detach_vector(inout DifferentialPair> dpx, vector.Differential dOut) -{ - dpx = diffPair(dpx.p, vector.dzero()); -} - // Natural Exponent __generic diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1d2b327d2..7e75d06b3 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -770,26 +770,6 @@ struct TriangleStream // Try to terminate the current draw or dispatch call (HLSL SM 4.0) void abort(); -// Detach and set derivatives to zero - -__generic -T detach(T x) -{ - return x; -} - -__generic -vector detach(vector x) -{ - return x; -} - -__generic -matrix detach(matrix x) -{ - return x; -} - // Absolute value (HLSL SM 1.0) __generic diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 75a3c2ff1..3567e2593 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1024,6 +1024,23 @@ namespace Slang }); } } + for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer) + { + if (auto genSubst = as(subst)) + { + for (auto arg : genSubst->getArgs()) + { + if (auto typeArg = as(arg)) + { + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet); + } + } + } + else if (auto thisSubst = as(subst)) + { + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet); + } + } return; } } diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 519ca91ff..bc89dc94e 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -585,6 +585,7 @@ namespace Slang void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) { + SLANG_UNUSED(stmt); if (getParentDifferentiableAttribute()) { if (!getParentFunc()) @@ -601,16 +602,6 @@ namespace Slang return; if (getParentFunc()->findModifier()) return; - - // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute, - // or a `[ForceUnroll]` attribet on loops. - if (stmt->hasModifier() || stmt->hasModifier() || stmt->hasModifier()) - { - } - else - { - getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); - } } } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 26d84720f..564c33268 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -488,7 +488,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (!diffReturnType) { - diffReturnType = argBuilder.getVoidType(); + diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); } auto callInst = argBuilder.emitCallInst( @@ -501,7 +501,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig *builder = afterBuilder; - if (diffReturnType->getOp() != kIROp_VoidType) + if (diffReturnType->getOp() == kIROp_DifferentialPairType) { IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); @@ -510,8 +510,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } else { - // Return the inst itself if the return value is void. - // This is fine since these values should never actually be used anywhere. + // Return the inst itself if the return value is non-differentiable. + // This is fine since these values should only be used by non-differentiable code. // return InstPair(callInst, callInst); } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ef6178976..55c0ee46d 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -13,6 +13,7 @@ #include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-init-local-var.h" #include "slang-ir-redundancy-removal.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -674,6 +675,106 @@ namespace Slang return fwdDiffFunc; } + void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc) + { + RefPtr domTree = computeDominatorTree(diffPropFunc); + auto firstBlock = diffPropFunc->getFirstBlock(); + if (!firstBlock) + return; + Dictionary instVars; + Dictionary cloneEnvs; + auto storeInstAsLocalVar = [&](IRInst* inst) + { + IRVar* var = nullptr; + if (instVars.TryGetValue(inst, var)) + return var; + IRBuilder builder(diffPropFunc); + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + var = builder.emitVar(inst->getDataType()); + builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType())); + + setInsertAfterOrdinaryInst(&builder, inst); + builder.emitStore(var, inst); + instVars[inst] = var; + return var; + }; + + IRBuilder builder(diffPropFunc); + List workList; + for (auto block : diffPropFunc->getBlocks()) + { + if (!block->findDecoration()) + continue; + cloneEnvs[block] = IRCloneEnv(); + for (auto inst : block->getChildren()) + { + workList.add(inst); + } + } + + for (Index i = 0; i < workList.getCount(); i++) + { + auto inst = workList[i]; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto operand = inst->getOperand(j); + if (operand->getOp() == kIROp_Block) + continue; + auto operandParent = inst->getOperand(j)->getParent(); + if (!operandParent) + continue; + if (operandParent->parent != diffPropFunc) + continue; + if (domTree->dominates(operandParent, inst->parent)) + continue; + + // The def site of the operand does not dominate the use. + // We need to insert a local variable to store this var. + + IRInst* operandReplacement = nullptr; + if (canInstBeStored(operand)) + { + auto var = storeInstAsLocalVar(operand); + builder.setInsertBefore(inst); + operandReplacement = builder.emitLoad(var); + } + else if (operand->getOp() == kIROp_Var) + { + // Var can just be hoisted to first block. + operand->insertBefore(firstBlock->getFirstOrdinaryInst()); + } + else + { + // For all other insts, we need to copy it to right before this inst. + // Before actually copying it, check if we have already copied it to + // any blocks that dominates this block. + auto dom = as(inst->getParent()); + while (dom) + { + auto subCloneEnv = cloneEnvs.TryGetValue(dom); + if (!subCloneEnv) break; + if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement)) + { + break; + } + dom = domTree->getImmediateDominator(dom); + } + // We have not found an existing clone in dominators, so we need to copy it + // to this block. + if (!operandReplacement) + { + auto subCloneEnv = cloneEnvs.TryGetValue(as(inst->getParent())); + builder.setInsertBefore(inst); + operandReplacement = cloneInst(subCloneEnv, &builder, operand); + workList.add(operandReplacement); + } + } + if (operandReplacement) + builder.replaceOperand(inst->getOperands() + j, operandReplacement); + } + } + } + InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) { SLANG_UNUSED(primalType); @@ -838,6 +939,8 @@ namespace Slang initializeLocalVariables(builder->getModule(), as(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getModule(), diffPropagateFunc); + insertVariableForRecomputedPrimalInsts(diffPropagateFunc); + stripTempDecorations(diffPropagateFunc); } ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( @@ -1151,16 +1254,15 @@ namespace Slang while (refUse) { auto nextUse = refUse->nextUse; - switch (refUse->getUser()->getOp()) + // Is this use the dest operand of a store inst? + // If so, replace it with writeRefReplacement, otherwise, refReplacement. + if (refUse->getUser()->getOp() == kIROp_Store && refUse == refUse->getUser()->getOperands()) { - case kIROp_Load: - refUse->set(diffRefReplacement); - break; - case kIROp_Store: + SLANG_RELEASE_ASSERT(diffWriteRefReplacement); refUse->set(diffWriteRefReplacement); - break; - default: - SLANG_RELEASE_ASSERT(!diffWriteRefReplacement); + } + else + { refUse->set(diffRefReplacement); } refUse = nextUse; diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index a638b873c..94bc1ef81 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -114,6 +114,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc); + void insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc); + void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc); InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 614559c9f..4e7539b48 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -137,39 +137,8 @@ struct ExtractPrimalFuncContext return false; } - // Only store allowed types. - if (isScalarIntegerType(inst->getDataType())) - { - } - else if (as(inst->getDataType())) - { - } - else - { - switch (inst->getDataType()->getOp()) - { - case kIROp_StructType: - case kIROp_OptionalType: - case kIROp_TupleType: - case kIROp_ArrayType: - case kIROp_DifferentialPairType: - case kIROp_InterfaceType: - case kIROp_AnyValueType: - case kIROp_ClassType: - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - case kIROp_MatrixType: - case kIROp_BoolType: - case kIROp_Param: - case kIROp_Specialize: - case kIROp_LookupWitness: - break; - default: - return false; - } - } + if (!canInstBeStored(inst)) + return false; // Never store certain opcodes. switch (inst->getOp()) @@ -507,11 +476,19 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( else { // Orindary value. - auto val = builder.emitFieldExtract( - inst->getFullType(), - intermediateVar, - structKeyDecor->getStructKey()); - inst->replaceUsesWith(val); + // We insert a fieldExtract at each use site instead of before `inst`, + // since at this stage of autodiff pass, `inst` does not necessarily + // dominate all the use sites if `inst` is defined in partial branch + // in a primal block. + while (auto iuse = inst->firstUse) + { + builder.setInsertBefore(iuse->getUser()); + auto val = builder.emitFieldExtract( + inst->getFullType(), + intermediateVar, + structKeyDecor->getStructKey()); + iuse->set(val); + } } instsToRemove.add(inst); } @@ -529,8 +506,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { inst->removeAndDeallocate(); } - - stripTempDecorations(func); return primalFunc; } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2d5261b63..944df2c81 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -498,34 +498,6 @@ struct DiffUnzipPass } } - void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) - { - if (as(inst)) - { - SLANG_RELEASE_ASSERT(as(inst->getParent())); - auto lastParam = as(inst->getParent())->getLastParam(); - builder->setInsertAfter(lastParam); - } - else - { - builder->setInsertBefore(inst); - } - } - - void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) - { - if (as(inst)) - { - SLANG_RELEASE_ASSERT(as(inst->getParent())); - auto lastParam = as(inst->getParent())->getLastParam(); - builder->setInsertAfter(lastParam); - } - else - { - builder->setInsertAfter(inst); - } - } - void processIndexedFwdBlock(IRBlock* fwdBlock) { if (!isBlockIndexed(fwdBlock)) @@ -794,6 +766,7 @@ struct DiffUnzipPass auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs); primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); + primalBuilder->markInstAsPrimal(primalVal); SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount()); @@ -1377,6 +1350,11 @@ struct DiffUnzipPass // Remove insts that were split. for (auto inst : splitInsts) { + if (!isDifferentiableType(diffTypeContext, inst->getDataType())) + { + inst->replaceUsesWith(lookupPrimalInst(inst)); + } + // Consistency check. for (auto use = inst->firstUse; use; use = use->nextUse) { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 97cdb644e..b630b798d 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -563,6 +563,30 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* return false; } +bool canInstBeStored(IRInst* inst) +{ + if (as(inst->getDataType())) + return true; + + switch (inst->getDataType()->getOp()) + { + case kIROp_StructType: + case kIROp_OptionalType: + case kIROp_TupleType: + case kIROp_ArrayType: + case kIROp_DifferentialPairType: + case kIROp_InterfaceType: + case kIROp_AnyValueType: + case kIROp_ClassType: + case kIROp_FloatType: + case kIROp_VectorType: + case kIROp_MatrixType: + return true; + default: + return false; + } +} + struct AutoDiffPass : public InstPassBase { DiagnosticSink* getSink() diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index f57fb2974..fa01d50ae 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -303,6 +303,8 @@ bool isBackwardDifferentiableFunc(IRInst* func); bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); +bool canInstBeStored(IRInst* inst); + inline bool isRelevantDifferentialPair(IRType* type) { if (as(type)) diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index c750b2d3d..1ee94e67e 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -353,6 +353,17 @@ public: auto loop = as(block->getTerminator()); if (!loop) continue; + bool hasBackEdge = false; + for (auto use = loop->getTargetBlock()->firstUse; use; use = use->nextUse) + { + if (use->getUser() != loop) + { + hasBackEdge = true; + break; + } + } + if (!hasBackEdge) + continue; if (loop->findDecoration() || loop->findDecoration()) { // We are good. diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 96ffb9bb2..7660c9526 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -701,6 +701,61 @@ static LegalVal legalizeRetVal( return LegalVal(); } +static void _addVal(List& rs, const LegalVal& legalVal) +{ + switch (legalVal.flavor) + { + case LegalVal::Flavor::simple: + rs.add(legalVal.getSimple()); + break; + case LegalVal::Flavor::tuple: + for (auto element : legalVal.getTuple()->elements) + _addVal(rs, element.val); + break; + case LegalVal::Flavor::pair: + _addVal(rs, legalVal.getPair()->ordinaryVal); + _addVal(rs, legalVal.getPair()->specialVal); + break; + case LegalVal::Flavor::none: + break; + default: + SLANG_UNEXPECTED("unhandled legalized val flavor"); + } +} + +static LegalVal legalizeUnconditionalBranch( + IRTypeLegalizationContext* context, + ArrayView args, + IRUnconditionalBranch* branchInst) +{ + List newArgs; + for (auto arg : args) + { + switch (arg.flavor) + { + case LegalVal::Flavor::none: + break; + case LegalVal::Flavor::simple: + newArgs.add(arg.getSimple()); + break; + case LegalVal::Flavor::pair: + _addVal(newArgs, arg.getPair()->ordinaryVal); + _addVal(newArgs, arg.getPair()->specialVal); + break; + case LegalVal::Flavor::tuple: + for (auto element : arg.getTuple()->elements) + { + _addVal(newArgs, element.val); + } + break; + default: + SLANG_UNIMPLEMENTED_X("Unknown legalized val flavor."); + } + } + context->builder->emitBranch(branchInst->getTargetBlock(), newArgs.getCount() - 1, newArgs.getBuffer() + 1); + return LegalVal(); +} + static LegalVal legalizeLoad( IRTypeLegalizationContext* context, LegalVal legalPtrVal) @@ -1610,11 +1665,69 @@ static LegalVal legalizeMakeStruct( } } +static LegalVal legalizeDefaultConstruct( + IRTypeLegalizationContext* context, + LegalType legalType) +{ + auto builder = context->builder; + + switch (legalType.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + + case LegalType::Flavor::simple: + { + return LegalVal::simple( + builder->emitDefaultConstruct(legalType.getSimple())); + } + + case LegalType::Flavor::pair: + { + auto pairType = legalType.getPair(); + auto pairInfo = pairType->pairInfo; + LegalType ordinaryType = pairType->ordinaryType; + LegalType specialType = pairType->specialType; + + LegalVal ordinaryVal = legalizeDefaultConstruct( + context, + ordinaryType); + + LegalVal specialVal = legalizeDefaultConstruct( + context, + specialType); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalType::Flavor::tuple: + { + auto tupleType = legalType.getTuple(); + + RefPtr resTupleInfo = new TuplePseudoVal(); + for (auto typeElem : tupleType->elements) + { + auto elemKey = typeElem.key; + TuplePseudoVal::Element resElem; + resElem.key = elemKey; + resElem.val = legalizeDefaultConstruct(context, typeElem.type); + resTupleInfo->elements.add(resElem); + } + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + static LegalVal legalizeInst( IRTypeLegalizationContext* context, IRInst* inst, LegalType type, - LegalVal const* args) + ArrayView args) { switch (inst->getOp()) { @@ -1647,8 +1760,14 @@ static LegalVal legalizeInst( return legalizeMakeStruct( context, type, - args, + args.getBuffer(), inst->getOperandCount()); + case kIROp_DefaultConstruct: + return legalizeDefaultConstruct( + context, + type); + case kIROp_unconditionalBranch: + return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst); case kIROp_undefined: return LegalVal(); case kIROp_GpuForeach: @@ -1896,7 +2015,7 @@ static LegalVal legalizeInst( context, inst, legalType, - legalArgs.getBuffer()); + legalArgs.getArrayView()); if (legalVal.flavor == LegalVal::Flavor::simple) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f3c4c2c82..3db036a8d 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -224,7 +224,7 @@ String dumpIRToString(IRInst* root) StringBuilder sb; StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); IRDumpOptions options = {}; -#if 0 +#if 1 options.flags = IRDumpOptions::Flag::DumpDebugIds; #endif dumpIR(root, options, nullptr, &writer); @@ -487,6 +487,34 @@ IROp getSwapSideComparisonOp(IROp op) } } +void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) +{ + if (as(inst)) + { + SLANG_RELEASE_ASSERT(as(inst->getParent())); + auto lastParam = as(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); + } + else + { + builder->setInsertBefore(inst); + } +} + +void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) +{ + if (as(inst)) + { + SLANG_RELEASE_ASSERT(as(inst->getParent())); + auto lastParam = as(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); + } + else + { + builder->setInsertAfter(inst); + } +} + bool isPureFunctionalCall(IRCall* call) { auto callee = getResolvedInstForDecorations(call->getCallee()); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 0fb26f791..8a12ab895 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -173,6 +173,12 @@ IRInst* getUndefInst(IRBuilder builder, IRModule* module); // The the equivalent op of (a op b) in (b op' a). For example, a > b is equivalent to b < a. So (<) ==> (>). IROp getSwapSideComparisonOp(IROp op); +// Set IRBuilder to insert before `inst`. If `inst` is a param, it will insert after the last param. +void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst); + +// Set IRBuilder to insert after `inst`. If `inst` is a param, it will insert after the last param. +void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst); + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index dc61a45af..accefc0c9 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5662,9 +5662,9 @@ namespace Slang #if SLANG_ENABLE_IR_BREAK_ALLOC if (context->options.flags & IRDumpOptions::Flag::DumpDebugIds) { - dump(context, "[#"); + dump(context, "{"); dump(context, String(inst->_debugUID)); - dump(context, "]"); + dump(context, "}\t"); } #else SLANG_UNUSED(context); @@ -5691,7 +5691,6 @@ namespace Slang { dump(context, "_"); } - dumpDebugID(context, inst); } static void dumpEncodeString( @@ -5819,6 +5818,7 @@ namespace Slang IRBlock* block) { context->indent--; + dumpDebugID(context, block); dump(context, "block "); dumpID(context, block); @@ -6050,7 +6050,6 @@ namespace Slang } dump(context, opInfo.name); - dumpDebugID(context, inst); dumpInstOperandList(context, inst); } @@ -6068,6 +6067,8 @@ namespace Slang dumpIRDecorations(context, inst); + dumpDebugID(context, inst); + // There are several ops we want to special-case here, // so that they will be more pleasant to look at. // @@ -6204,7 +6205,10 @@ namespace Slang context.options = options; context.sourceManager = sourceManager; - dumpInst(&context, globalVal); + if (globalVal->getOp() == kIROp_Module) + dumpIRModule(&context, globalVal->getModule()); + else + dumpInst(&context, globalVal); writer->write(sb.getBuffer(), sb.getLength()); writer->flush(); -- cgit v1.2.3