diff options
| author | Yong He <yonghe@outlook.com> | 2024-05-29 11:14:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-29 11:14:22 -0700 |
| commit | 83f176ba8a3bae5533470aed6a90663653f894b8 (patch) | |
| tree | 3e39a674cb4662c946598526f633302f139e14ab /source/slang/slang-emit.cpp | |
| parent | c1e34c5a29d99d8a70b4e78313bfd3d539d9206e (diff) | |
Add options to speedup compilation. (#4240)
* Add options to speedup compilation.
* Fix.
* Plumb options to DCE pass.
* Revert debug change.
* Fix regressions.
* More optimizations.
* more cleanup and fixes.
* remove comment.
* Fixes.
* Another fix.
* Fix errors.
* Fix errors.
* Add comments.
Diffstat (limited to 'source/slang/slang-emit.cpp')
| -rw-r--r-- | source/slang/slang-emit.cpp | 288 |
1 files changed, 244 insertions, 44 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 531084434..428c9bf66 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -48,7 +48,6 @@ #include "slang-ir-lower-bit-cast.h" #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-l-value-cast.h" -#include "slang-ir-lower-size-of.h" #include "slang-ir-lower-reinterpret.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-legalize-vector-types.h" @@ -214,6 +213,155 @@ struct LinkingAndOptimizationOptions CLikeSourceEmitter* sourceEmitter = nullptr; }; +// To improve the performance of our backend, we will try to avoid running +// passes related to features not used in the user code. +// To do so, we will scan the IR module once, and determine which passes are needed +// based on the instructions used in the IR module. +// This will allow us to skip running passes that are not needed, without having to +// run all the passes only to find out that no work is needed. +// This is especially important for the performance of the backend, as some passes +// have an initialization cost (such as building reference graphs or DOM trees) that +// can be expensive. +// +struct RequiredLoweringPassSet +{ + bool resultType; + bool optionalType; + bool combinedTextureSamplers; + bool reinterpret; + bool generics; + bool bindExistential; + bool autodiff; + bool derivativePyBindWrapper; + bool bitcast; + bool existentialTypeLayout; + bool bindingQuery; + bool meshOutput; + bool higherOrderFunc; + bool glslGlobalVar; + bool glslSSBO; + bool byteAddressBuffer; +}; + +// Scan the IR module and determine which lowering/legalization passes are needed based +// on the instructions we see. 
+// +void calcRequiredLoweringPassSet(RequiredLoweringPassSet& result, CodeGenContext* codeGenContext, IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_ResultType: + result.resultType = true; + break; + case kIROp_OptionalType: + result.optionalType = true; + break; + case kIROp_TextureType: + if (!isKhronosTarget(codeGenContext->getTargetReq())) + { + if (auto texType = as<IRTextureType>(inst)) + { + auto isCombined = texType->getIsCombinedInst(); + if (auto isCombinedVal = as<IRIntLit>(isCombined)) + { + if (isCombinedVal->getValue() != 0) + { + result.combinedTextureSamplers = true; + } + } + else + { + result.combinedTextureSamplers = true; + } + } + } + break; + case kIROp_PseudoPtrType: + case kIROp_BoundInterfaceType: + case kIROp_BindExistentialsType: + result.generics = true; + result.existentialTypeLayout = true; + break; + case kIROp_GetRegisterIndex: + case kIROp_GetRegisterSpace: + result.bindingQuery = true; + break; + case kIROp_BackwardDifferentiate: + case kIROp_ForwardDifferentiate: + case kIROp_MakeDifferentialPairUserCode: + result.autodiff = true; + break; + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: + result.meshOutput = true; + break; + case kIROp_CreateExistentialObject: + case kIROp_MakeExistential: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + case kIROp_WrapExistential: + case kIROp_LookupWitness: + result.generics = true; + break; + case kIROp_Specialize: + { + auto specInst = as<IRSpecialize>(inst); + if (!findAnyTargetIntrinsicDecoration(getResolvedInstForDecorations(specInst))) + result.generics = true; + } + break; + case kIROp_Reinterpret: + result.reinterpret = true; + break; + case kIROp_BitCast: + result.bitcast = true; + break; + case kIROp_AutoPyBindCudaDecoration: + result.derivativePyBindWrapper = true; + break; + case kIROp_Param: + if (as<IRFuncType>(inst->getDataType())) + result.higherOrderFunc = 
true; + break; + case kIROp_GlobalInputDecoration: + case kIROp_GlobalOutputDecoration: + case kIROp_GetWorkGroupSize: + result.glslGlobalVar = true; + break; + case kIROp_BindExistentialSlotsDecoration: + result.bindExistential = true; + result.generics = true; + result.existentialTypeLayout = true; + break; + case kIROp_GLSLShaderStorageBufferType: + result.glslSSBO = true; + break; + case kIROp_ByteAddressBufferLoad: + case kIROp_ByteAddressBufferStore: + case kIROp_HLSLRWByteAddressBufferType: + case kIROp_HLSLByteAddressBufferType: + result.byteAddressBuffer = true; + break; + } + if (!result.generics || !result.existentialTypeLayout) + { + // If any instruction has an interface type, we need to run + // the generics lowering pass. + auto type = inst->getDataType(); + if (type && type->getOp() == kIROp_InterfaceType) + { + result.generics = true; + result.existentialTypeLayout = true; + } + } + for (auto child : inst->getDecorationsAndChildren()) + { + calcRequiredLoweringPassSet(result, codeGenContext, child); + } +} + Result linkAndOptimizeIR( CodeGenContext* codeGenContext, LinkingAndOptimizationOptions const& options, @@ -252,10 +400,14 @@ Result linkAndOptimizeIR( // un-specialized IR. dumpIRIfEnabled(codeGenContext, irModule, "POST IR VALIDATION"); - if(!isKhronosTarget(targetRequest)) + // Scan the IR module and determine which lowering/legalization passes are needed. + RequiredLoweringPassSet requiredLoweringPassSet = {}; + calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); + + if(!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO) lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink); - if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) + if (requiredLoweringPassSet.glslGlobalVar) translateGLSLGlobalVar(codeGenContext, irModule); // Replace any global constants with their values. 
@@ -274,7 +426,8 @@ Result linkAndOptimizeIR( // shader parameters for those slots, to be wired up to // use sites. // - bindExistentialSlots(irModule, sink); + if (requiredLoweringPassSet.bindExistential) + bindExistentialSlots(irModule, sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "EXISTENTIALS BOUND"); #endif @@ -354,7 +507,8 @@ Result linkAndOptimizeIR( break; } - lowerOptionalType(irModule, sink); + if (requiredLoweringPassSet.optionalType) + lowerOptionalType(irModule, sink); switch (target) { @@ -370,7 +524,8 @@ Result linkAndOptimizeIR( } // Lower `Result<T,E>` types into ordinary struct types. - lowerResultType(irModule, sink); + if (requiredLoweringPassSet.resultType) + lowerResultType(irModule, sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED"); @@ -382,7 +537,9 @@ Result linkAndOptimizeIR( IRSimplificationOptions defaultIRSimplificationOptions = IRSimplificationOptions::getDefault(targetProgram); IRSimplificationOptions fastIRSimplificationOptions = IRSimplificationOptions::getFast(targetProgram); + IRDeadCodeEliminationOptions deadCodeEliminationOptions = IRDeadCodeEliminationOptions(); fastIRSimplificationOptions.minimalOptimization = defaultIRSimplificationOptions.minimalOptimization; + deadCodeEliminationOptions.useFastAnalysis = fastIRSimplificationOptions.minimalOptimization; simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); @@ -403,7 +560,8 @@ Result linkAndOptimizeIR( fuseCallsToSaturatedCooperation(irModule); // Generate any requested derivative wrappers - generateDerivativeWrappers(irModule, sink); + if (requiredLoweringPassSet.derivativePyBindWrapper) + generateDerivativeWrappers(irModule, sink); // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would @@ -436,8 +594,11 @@ Result linkAndOptimizeIR( return SLANG_FAIL; dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); - applySparseConditionalConstantPropagation(irModule, 
codeGenContext->getSink()); - eliminateDeadCode(irModule); + if (changed) + { + applySparseConditionalConstantPropagation(irModule, codeGenContext->getSink()); + } + eliminateDeadCode(irModule, deadCodeEliminationOptions); validateIRModuleIfEnabled(codeGenContext, irModule); @@ -446,10 +607,13 @@ Result linkAndOptimizeIR( performMandatoryEarlyInlining(irModule); // Unroll loops. - if (codeGenContext->getSink()->getErrorCount() == 0) + if (!fastIRSimplificationOptions.minimalOptimization) { - if (!unrollLoopsInModule(targetProgram, irModule, codeGenContext->getSink())) - return SLANG_FAIL; + if (codeGenContext->getSink()->getErrorCount() == 0) + { + if (!unrollLoopsInModule(targetProgram, irModule, codeGenContext->getSink())) + return SLANG_FAIL; + } } // Few of our targets support higher order functions, and @@ -457,23 +621,30 @@ Result linkAndOptimizeIR( // which do. // Specialize away these parameters // TODO: We should implement a proper defunctionalization pass - if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) + if (requiredLoweringPassSet.higherOrderFunc) changed |= specializeHigherOrderParameters(codeGenContext, irModule); - dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); - enableIRValidationAtInsert(); - changed |= processAutodiffCalls(targetProgram, irModule, sink); - disableIRValidationAtInsert(); - dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); + if (requiredLoweringPassSet.autodiff) + { + dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); + enableIRValidationAtInsert(); + changed |= processAutodiffCalls(targetProgram, irModule, sink); + disableIRValidationAtInsert(); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); + } if (!changed) break; } - finalizeAutoDiffPass(targetProgram, irModule); + if (requiredLoweringPassSet.autodiff) + finalizeAutoDiffPass(targetProgram, irModule); finalizeSpecialization(irModule); + requiredLoweringPassSet = {}; + 
calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); + switch (target) { case CodeGenTarget::PyTorchCppBinding: @@ -491,7 +662,7 @@ Result linkAndOptimizeIR( break; } - if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) + if (targetProgram->getOptionSet().shouldRunNonEssentialValidation()) checkForRecursiveTypes(irModule, sink); if (sink->getErrorCount() != 0) @@ -507,7 +678,9 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } - lowerReinterpret(targetProgram, irModule, sink); + if (requiredLoweringPassSet.reinterpret) + lowerReinterpret(targetProgram, irModule, sink); + if (sink->getErrorCount() != 0) return SLANG_FAIL; @@ -517,10 +690,18 @@ Result linkAndOptimizeIR( // but are not used for dynamic dispatch, unpin them so we don't // do unnecessary work to lower them. unpinWitnessTables(irModule); + + if (fastIRSimplificationOptions.minimalOptimization) + { + eliminateDeadCode(irModule, deadCodeEliminationOptions); + } + else + { + simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); + } - simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); - - if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc)) + if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc) && + targetProgram->getOptionSet().shouldRunNonEssentialValidation()) { // We could fail because (perhaps, somehow) end up with getStringHash that the operand is not a string literal SLANG_RETURN_ON_FAIL(checkGetStringHashInsts(irModule, sink)); @@ -530,7 +711,10 @@ Result linkAndOptimizeIR( // generics / interface types to ordinary functions and types using // function pointers. 
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-LOWER-GENERICS"); - lowerGenerics(targetProgram, irModule, sink); + if (requiredLoweringPassSet.generics) + lowerGenerics(targetProgram, irModule, sink); + else + cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); if (sink->getErrorCount() != 0) @@ -547,7 +731,14 @@ Result linkAndOptimizeIR( // up downstream passes like type legalization, so we // will run a DCE pass to clean up after the specialization. // - simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); + if (fastIRSimplificationOptions.minimalOptimization) + { + eliminateDeadCode(irModule, deadCodeEliminationOptions); + } + else + { + simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); + } validateIRModuleIfEnabled(codeGenContext, irModule); @@ -569,11 +760,15 @@ Result linkAndOptimizeIR( case CodeGenTarget::Metal: case CodeGenTarget::MetalLib: case CodeGenTarget::MetalLibAssembly: - lowerCombinedTextureSamplers(irModule, sink); + if (requiredLoweringPassSet.combinedTextureSamplers) + lowerCombinedTextureSamplers(irModule, sink); break; } - addUserTypeHintDecorations(irModule); + if (codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::VulkanEmitReflection)) + { + addUserTypeHintDecorations(irModule); + } // We don't need the legalize pass for C/C++ based types if(options.shouldLegalizeExistentialAndResourceTypes ) @@ -603,10 +798,13 @@ Result linkAndOptimizeIR( // we need to replace it with just an `X`, after which we // will have (more) legal shader code. 
// - legalizeExistentialTypeLayout( - targetProgram, - irModule, - sink); + if (requiredLoweringPassSet.existentialTypeLayout) + { + legalizeExistentialTypeLayout( + targetProgram, + irModule, + sink); + } #if 0 dumpIRIfEnabled(codeGenContext, irModule, "EXISTENTIALS LEGALIZED"); @@ -652,7 +850,10 @@ Result linkAndOptimizeIR( // to see if we can clean up any temporaries created by legalization. // (e.g., things that used to be aggregated might now be split up, // so that we can work with the individual fields). - simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); + if (fastIRSimplificationOptions.minimalOptimization) + eliminateDeadCode(irModule, deadCodeEliminationOptions); + else + simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA"); @@ -678,7 +879,6 @@ Result linkAndOptimizeIR( { specializeArrayParameters(codeGenContext, irModule); } - eliminateDeadCode(irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION"); @@ -713,6 +913,7 @@ Result linkAndOptimizeIR( // of aggregate types from/to byte-address buffers into // stores of individual scalar or vector values. // + if (requiredLoweringPassSet.byteAddressBuffer) { ByteAddressBufferLegalizationOptions byteAddressBufferOptions; @@ -948,7 +1149,7 @@ Result linkAndOptimizeIR( // // We run DCE pass again to clean things up. // - eliminateDeadCode(irModule); + eliminateDeadCode(irModule, deadCodeEliminationOptions); if (isKhronosTarget(targetRequest)) { @@ -965,14 +1166,16 @@ Result linkAndOptimizeIR( // Lower the `getRegisterIndex` and `getRegisterSpace` intrinsics. // - lowerBindingQueries(irModule, sink); + if (requiredLoweringPassSet.bindingQuery) + lowerBindingQueries(irModule, sink); // For some small improvement in type safety we represent these as opaque // structs instead of regular arrays. 
// // If any have survived this far, change them back to regular (decorated) // arrays that the emitters can deal with. - legalizeMeshOutputTypes(irModule); + if (requiredLoweringPassSet.meshOutput) + legalizeMeshOutputTypes(irModule); if (options.shouldLegalizeExistentialAndResourceTypes) { @@ -997,20 +1200,17 @@ Result linkAndOptimizeIR( rcpWOfPositionInput(irModule); } - // Lower sizeof/alignof - - lowerSizeOfLike(targetProgram, irModule, sink); - // Lower all bit_cast operations on complex types into leaf-level // bit_cast on basic types. - lowerBitCast(targetProgram, irModule, sink); + if (requiredLoweringPassSet.bitcast) + lowerBitCast(targetProgram, irModule, sink); bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly(); if (emitSpirvDirectly) { performIntrinsicFunctionInlining(irModule); - eliminateDeadCode(irModule); + eliminateDeadCode(irModule, deadCodeEliminationOptions); } eliminateMultiLevelBreak(irModule); @@ -1076,7 +1276,7 @@ Result linkAndOptimizeIR( // For now we are avoiding that problem by simply *not* emitting live-range // information when we fix variable scoping later on. - // Depending on the target, certain things that were represented as // single IR instructions will need to be emitted with the help of // function declaratons in output high-level code. // |
