diff options
| author | Yong He <yonghe@outlook.com> | 2024-05-29 11:14:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-29 11:14:22 -0700 |
| commit | 83f176ba8a3bae5533470aed6a90663653f894b8 (patch) | |
| tree | 3e39a674cb4662c946598526f633302f139e14ab /source/slang | |
| parent | c1e34c5a29d99d8a70b4e78313bfd3d539d9206e (diff) | |
Add options to speedup compilation. (#4240)
* Add options to speedup compilation.
* Fix.
* Plumb options to DCE pass.
* Revert debug change.
* Fix regressions.
* More optimizations.
* more cleanup and fixes.
* remove comment.
* Fixes.
* Another fix.
* Fix errors.
* Fix errors.
* Add comments.
Diffstat (limited to 'source/slang')
22 files changed, 492 insertions, 308 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index e1dc8a59a..2b2de85bd 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -150,7 +150,7 @@ struct ByteAddressBuffer { case hlsl: __intrinsic_asm ".Load2"; default: - return __byteAddressBufferLoad<uint2>(this, location, __alignOf_intrinsic<uint2>()); + return __byteAddressBufferLoad<uint2>(this, location, __naturalStrideOf<uint2>()); } } @@ -192,7 +192,7 @@ struct ByteAddressBuffer { case hlsl: __intrinsic_asm ".Load3"; default: - return __byteAddressBufferLoad<uint3>(this, location, __alignOf_intrinsic<uint3>()); + return __byteAddressBufferLoad<uint3>(this, location, __naturalStrideOf<uint3>()); } } @@ -234,7 +234,7 @@ struct ByteAddressBuffer { case hlsl: __intrinsic_asm ".Load4"; default: - return __byteAddressBufferLoad<uint4>(this, location, __alignOf_intrinsic<uint4>()); + return __byteAddressBufferLoad<uint4>(this, location, __naturalStrideOf<uint4>()); } } @@ -259,7 +259,7 @@ struct ByteAddressBuffer [ForceInline] T LoadAligned<T>(int location) { - return __byteAddressBufferLoad<T>(this, location, __alignOf_intrinsic<T>()); + return __byteAddressBufferLoad<T>(this, location, __naturalStrideOf<T>()); } }; @@ -3758,7 +3758,7 @@ struct $(item.name) { case hlsl: __intrinsic_asm ".Load2"; default: - return __byteAddressBufferLoad<uint2>(this, location, __alignOf_intrinsic<uint2>()); + return __byteAddressBufferLoad<uint2>(this, location, __naturalStrideOf<uint2>()); } } @@ -3800,7 +3800,7 @@ struct $(item.name) { case hlsl: __intrinsic_asm ".Load3"; default: - return __byteAddressBufferLoad<uint3>(this, location, __alignOf_intrinsic<uint3>()); + return __byteAddressBufferLoad<uint3>(this, location, __naturalStrideOf<uint3>()); } } @@ -3842,7 +3842,7 @@ struct $(item.name) { case hlsl: __intrinsic_asm ".Load4"; default: - return __byteAddressBufferLoad<uint4>(this, location, __alignOf_intrinsic<uint4>()); + return __byteAddressBufferLoad<uint4>(this, location, __naturalStrideOf<uint4>()); } } @@ -3870,7 +3870,7 @@ struct $(item.name) [require(cpp_cuda_glsl_hlsl_spirv, byteaddressbuffer_rw)] T LoadAligned<T>(int location) { - return __byteAddressBufferLoad<T>(this, location, __alignOf_intrinsic<T>()); + return __byteAddressBufferLoad<T>(this, location, __naturalStrideOf<T>()); } ${{{{ @@ -4763,7 +4763,7 @@ ${{{{ { case hlsl: __intrinsic_asm ".Store2"; default: - __byteAddressBufferStore(this, address, __alignOf_intrinsic<uint2>(), value); + __byteAddressBufferStore(this, address, __naturalStrideOf<uint2>(), value); } } @@ -4800,7 +4800,7 @@ ${{{{ { case hlsl: __intrinsic_asm ".Store3"; default: - __byteAddressBufferStore(this, address, __alignOf_intrinsic<uint3>(), value); + __byteAddressBufferStore(this, address, __naturalStrideOf<uint3>(), value); } } @@ -4837,7 +4837,7 @@ ${{{{ { case hlsl: __intrinsic_asm ".Store4"; default: - __byteAddressBufferStore(this, address, __alignOf_intrinsic<uint4>(), value); + __byteAddressBufferStore(this, address, __naturalStrideOf<uint4>(), value); } } @@ -4856,7 +4856,7 @@ ${{{{ [ForceInline] void StoreAligned<T>(int offset, T value) { - __byteAddressBufferStore(this, offset, __alignOf_intrinsic<T>(), value); + __byteAddressBufferStore(this, offset, __naturalStrideOf<T>(), value); } }; diff --git a/source/slang/slang-compiler-options.h b/source/slang/slang-compiler-options.h index 25fbbc407..221d0189f 100644 --- a/source/slang/slang-compiler-options.h +++ b/source/slang/slang-compiler-options.h @@ -366,6 +366,11 @@ namespace Slang return getBoolOption(CompilerOptionName::MinimumSlangOptimization); } + bool shouldRunNonEssentialValidation() + { + return !getBoolOption(CompilerOptionName::DisableNonEssentialValidations); + } + FloatingPointMode getFloatingPointMode() { return getEnumOption<FloatingPointMode>(CompilerOptionName::FloatingPointMode); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 531084434..428c9bf66 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -48,7 +48,6 @@ #include "slang-ir-lower-bit-cast.h" #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-l-value-cast.h" -#include "slang-ir-lower-size-of.h" #include "slang-ir-lower-reinterpret.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-legalize-vector-types.h" @@ -214,6 +213,155 @@ struct LinkingAndOptimizationOptions CLikeSourceEmitter* sourceEmitter = nullptr; }; +// To improve the performance of our backend, we will try to avoid running +// passes related to features not used in the user code. +// To do so, we will scan the IR module once, and determine which passes are needed +// based on the instructions used in the IR module. +// This will allow us to skip running passes that are not needed, without having to +// run all the passes only to find out that no work is needed. +// This is especially important for the performance of the backend, as some passes +// have an initialization cost (such as building reference graphs or DOM trees) that +// can be expensive. +// +struct RequiredLoweringPassSet +{ + bool resultType; + bool optionalType; + bool combinedTextureSamplers; + bool reinterpret; + bool generics; + bool bindExistential; + bool autodiff; + bool derivativePyBindWrapper; + bool bitcast; + bool existentialTypeLayout; + bool bindingQuery; + bool meshOutput; + bool higherOrderFunc; + bool glslGlobalVar; + bool glslSSBO; + bool byteAddressBuffer; +}; + +// Scan the IR module and determine which lowering/legalization passes are needed based +// on the instructions we see. +// +void calcRequiredLoweringPassSet(RequiredLoweringPassSet& result, CodeGenContext* codeGenContext, IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_ResultType: + result.resultType = true; + break; + case kIROp_OptionalType: + result.optionalType = true; + break; + case kIROp_TextureType: + if (!isKhronosTarget(codeGenContext->getTargetReq())) + { + if (auto texType = as<IRTextureType>(inst)) + { + auto isCombined = texType->getIsCombinedInst(); + if (auto isCombinedVal = as<IRIntLit>(isCombined)) + { + if (isCombinedVal->getValue() != 0) + { + result.combinedTextureSamplers = true; + } + } + else + { + result.combinedTextureSamplers = true; + } + } + } + break; + case kIROp_PseudoPtrType: + case kIROp_BoundInterfaceType: + case kIROp_BindExistentialsType: + result.generics = true; + result.existentialTypeLayout = true; + break; + case kIROp_GetRegisterIndex: + case kIROp_GetRegisterSpace: + result.bindingQuery = true; + break; + case kIROp_BackwardDifferentiate: + case kIROp_ForwardDifferentiate: + case kIROp_MakeDifferentialPairUserCode: + result.autodiff = true; + break; + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: + result.meshOutput = true; + break; + case kIROp_CreateExistentialObject: + case kIROp_MakeExistential: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + case kIROp_WrapExistential: + case kIROp_LookupWitness: + result.generics = true; + break; + case kIROp_Specialize: + { + auto specInst = as<IRSpecialize>(inst); + if (!findAnyTargetIntrinsicDecoration(getResolvedInstForDecorations(specInst))) + result.generics = true; + } + break; + case kIROp_Reinterpret: + result.reinterpret = true; + break; + case kIROp_BitCast: + result.bitcast = true; + break; + case kIROp_AutoPyBindCudaDecoration: + result.derivativePyBindWrapper = true; + break; + case kIROp_Param: + if (as<IRFuncType>(inst->getDataType())) + result.higherOrderFunc = true; + break; + case kIROp_GlobalInputDecoration: + case kIROp_GlobalOutputDecoration: + case kIROp_GetWorkGroupSize: + result.glslGlobalVar = true; + break; + case kIROp_BindExistentialSlotsDecoration: + result.bindExistential = true; + result.generics = true; + result.existentialTypeLayout = true; + break; + case kIROp_GLSLShaderStorageBufferType: + result.glslSSBO = true; + break; + case kIROp_ByteAddressBufferLoad: + case kIROp_ByteAddressBufferStore: + case kIROp_HLSLRWByteAddressBufferType: + case kIROp_HLSLByteAddressBufferType: + result.byteAddressBuffer = true; + break; + } + if (!result.generics || !result.existentialTypeLayout) + { + // If any instruction has an interface type, we need to run + // the generics lowering pass. + auto type = inst->getDataType(); + if (type && type->getOp() == kIROp_InterfaceType) + { + result.generics = true; + result.existentialTypeLayout = true; + } + } + for (auto child : inst->getDecorationsAndChildren()) + { + calcRequiredLoweringPassSet(result, codeGenContext, child); + } +} + Result linkAndOptimizeIR( CodeGenContext* codeGenContext, LinkingAndOptimizationOptions const& options, @@ -252,10 +400,14 @@ Result linkAndOptimizeIR( // un-specialized IR. dumpIRIfEnabled(codeGenContext, irModule, "POST IR VALIDATION"); - if(!isKhronosTarget(targetRequest)) + // Scan the IR module and determine which lowering/legalization passes are needed. + RequiredLoweringPassSet requiredLoweringPassSet = {}; + calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); + + if(!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO) lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink); - if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) + if (requiredLoweringPassSet.glslGlobalVar) translateGLSLGlobalVar(codeGenContext, irModule); // Replace any global constants with their values. @@ -274,7 +426,8 @@ Result linkAndOptimizeIR( // shader parameters for those slots, to be wired up to // use sites. // - bindExistentialSlots(irModule, sink); + if (requiredLoweringPassSet.bindExistential) + bindExistentialSlots(irModule, sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "EXISTENTIALS BOUND"); #endif @@ -354,7 +507,8 @@ Result linkAndOptimizeIR( break; } - lowerOptionalType(irModule, sink); + if (requiredLoweringPassSet.optionalType) + lowerOptionalType(irModule, sink); switch (target) { @@ -370,7 +524,8 @@ Result linkAndOptimizeIR( } // Lower `Result<T,E>` types into ordinary struct types. - lowerResultType(irModule, sink); + if (requiredLoweringPassSet.resultType) + lowerResultType(irModule, sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED"); @@ -382,7 +537,9 @@ Result linkAndOptimizeIR( IRSimplificationOptions defaultIRSimplificationOptions = IRSimplificationOptions::getDefault(targetProgram); IRSimplificationOptions fastIRSimplificationOptions = IRSimplificationOptions::getFast(targetProgram); + IRDeadCodeEliminationOptions deadCodeEliminationOptions = IRDeadCodeEliminationOptions(); fastIRSimplificationOptions.minimalOptimization = defaultIRSimplificationOptions.minimalOptimization; + deadCodeEliminationOptions.useFastAnalysis = fastIRSimplificationOptions.minimalOptimization; simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); @@ -403,7 +560,8 @@ Result linkAndOptimizeIR( fuseCallsToSaturatedCooperation(irModule); // Generate any requested derivative wrappers - generateDerivativeWrappers(irModule, sink); + if (requiredLoweringPassSet.derivativePyBindWrapper) + generateDerivativeWrappers(irModule, sink); // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would @@ -436,8 +594,11 @@ Result linkAndOptimizeIR( return SLANG_FAIL; dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); - applySparseConditionalConstantPropagation(irModule, codeGenContext->getSink()); - eliminateDeadCode(irModule); + if (changed) + { + applySparseConditionalConstantPropagation(irModule, codeGenContext->getSink()); + } + eliminateDeadCode(irModule, deadCodeEliminationOptions); validateIRModuleIfEnabled(codeGenContext, irModule); @@ -446,10 +607,13 @@ Result linkAndOptimizeIR( performMandatoryEarlyInlining(irModule); // Unroll loops. - if (codeGenContext->getSink()->getErrorCount() == 0) + if (!fastIRSimplificationOptions.minimalOptimization) { - if (!unrollLoopsInModule(targetProgram, irModule, codeGenContext->getSink())) - return SLANG_FAIL; + if (codeGenContext->getSink()->getErrorCount() == 0) + { + if (!unrollLoopsInModule(targetProgram, irModule, codeGenContext->getSink())) + return SLANG_FAIL; + } } // Few of our targets support higher order functions, and @@ -457,23 +621,30 @@ Result linkAndOptimizeIR( // which do. // Specialize away these parameters // TODO: We should implement a proper defunctionalization pass - if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) + if (requiredLoweringPassSet.higherOrderFunc) changed |= specializeHigherOrderParameters(codeGenContext, irModule); - dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); - enableIRValidationAtInsert(); - changed |= processAutodiffCalls(targetProgram, irModule, sink); - disableIRValidationAtInsert(); - dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); + if (requiredLoweringPassSet.autodiff) + { + dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); + enableIRValidationAtInsert(); + changed |= processAutodiffCalls(targetProgram, irModule, sink); + disableIRValidationAtInsert(); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); + } if (!changed) break; } - finalizeAutoDiffPass(targetProgram, irModule); + if (requiredLoweringPassSet.autodiff) + finalizeAutoDiffPass(targetProgram, irModule); finalizeSpecialization(irModule); + requiredLoweringPassSet = {}; + calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); + switch (target) { case CodeGenTarget::PyTorchCppBinding: @@ -491,7 +662,7 @@ Result linkAndOptimizeIR( break; } - if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) + if (targetProgram->getOptionSet().shouldRunNonEssentialValidation()) checkForRecursiveTypes(irModule, sink); if (sink->getErrorCount() != 0) @@ -507,7 +678,9 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } - lowerReinterpret(targetProgram, irModule, sink); + if (requiredLoweringPassSet.reinterpret) + lowerReinterpret(targetProgram, irModule, sink); + if (sink->getErrorCount() != 0) return SLANG_FAIL; @@ -517,10 +690,18 @@ Result linkAndOptimizeIR( // but are not used for dynamic dispatch, unpin them so we don't // do unnecessary work to lower them. unpinWitnessTables(irModule); + + if (fastIRSimplificationOptions.minimalOptimization) + { + eliminateDeadCode(irModule, deadCodeEliminationOptions); + } + else + { + simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); + } - simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); - - if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc)) + if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc) && + targetProgram->getOptionSet().shouldRunNonEssentialValidation()) { // We could fail because (perhaps, somehow) end up with getStringHash that the operand is not a string literal SLANG_RETURN_ON_FAIL(checkGetStringHashInsts(irModule, sink)); @@ -530,7 +711,10 @@ Result linkAndOptimizeIR( // generics / interface types to ordinary functions and types using // function pointers. dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-LOWER-GENERICS"); - lowerGenerics(targetProgram, irModule, sink); + if (requiredLoweringPassSet.generics) + lowerGenerics(targetProgram, irModule, sink); + else + cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); if (sink->getErrorCount() != 0) @@ -547,7 +731,14 @@ Result linkAndOptimizeIR( // up downstream passes like type legalization, so we // will run a DCE pass to clean up after the specialization. // - simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); + if (fastIRSimplificationOptions.minimalOptimization) + { + eliminateDeadCode(irModule, deadCodeEliminationOptions); + } + else + { + simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); + } validateIRModuleIfEnabled(codeGenContext, irModule); @@ -569,11 +760,15 @@ Result linkAndOptimizeIR( case CodeGenTarget::Metal: case CodeGenTarget::MetalLib: case CodeGenTarget::MetalLibAssembly: - lowerCombinedTextureSamplers(irModule, sink); + if (requiredLoweringPassSet.combinedTextureSamplers) + lowerCombinedTextureSamplers(irModule, sink); break; } - addUserTypeHintDecorations(irModule); + if (codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::VulkanEmitReflection)) + { + addUserTypeHintDecorations(irModule); + } // We don't need the legalize pass for C/C++ based types if(options.shouldLegalizeExistentialAndResourceTypes ) @@ -603,10 +798,13 @@ Result linkAndOptimizeIR( // we need to replace it with just an `X`, after which we // will have (more) legal shader code. // - legalizeExistentialTypeLayout( - targetProgram, - irModule, - sink); + if (requiredLoweringPassSet.existentialTypeLayout) + { + legalizeExistentialTypeLayout( + targetProgram, + irModule, + sink); + } #if 0 dumpIRIfEnabled(codeGenContext, irModule, "EXISTENTIALS LEGALIZED"); @@ -652,7 +850,10 @@ Result linkAndOptimizeIR( // to see if we can clean up any temporaries created by legalization. // (e.g., things that used to be aggregated might now be split up, // so that we can work with the individual fields). - simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); + if (fastIRSimplificationOptions.minimalOptimization) + eliminateDeadCode(irModule, deadCodeEliminationOptions); + else + simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA"); @@ -678,7 +879,6 @@ Result linkAndOptimizeIR( { specializeArrayParameters(codeGenContext, irModule); } - eliminateDeadCode(irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION"); @@ -713,6 +913,7 @@ Result linkAndOptimizeIR( // of aggregate types from/to byte-address buffers into // stores of individual scalar or vector values. // + if (requiredLoweringPassSet.byteAddressBuffer) { ByteAddressBufferLegalizationOptions byteAddressBufferOptions; @@ -948,7 +1149,7 @@ Result linkAndOptimizeIR( // // We run DCE pass again to clean things up. // - eliminateDeadCode(irModule); + eliminateDeadCode(irModule, deadCodeEliminationOptions); if (isKhronosTarget(targetRequest)) { @@ -965,14 +1166,16 @@ Result linkAndOptimizeIR( // Lower the `getRegisterIndex` and `getRegisterSpace` intrinsics. // - lowerBindingQueries(irModule, sink); + if (requiredLoweringPassSet.bindingQuery) + lowerBindingQueries(irModule, sink); // For some small improvement in type safety we represent these as opaque // structs instead of regular arrays. // // If any have survived this far, change them back to regular (decorated) // arrays that the emitters can deal with. - legalizeMeshOutputTypes(irModule); + if (requiredLoweringPassSet.meshOutput) + legalizeMeshOutputTypes(irModule); if (options.shouldLegalizeExistentialAndResourceTypes) { @@ -997,20 +1200,17 @@ Result linkAndOptimizeIR( rcpWOfPositionInput(irModule); } - // Lower sizeof/alignof - - lowerSizeOfLike(targetProgram, irModule, sink); - // Lower all bit_cast operations on complex types into leaf-level // bit_cast on basic types. - lowerBitCast(targetProgram, irModule, sink); + if (requiredLoweringPassSet.bitcast) + lowerBitCast(targetProgram, irModule, sink); bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly(); if (emitSpirvDirectly) { performIntrinsicFunctionInlining(irModule); - eliminateDeadCode(irModule); + eliminateDeadCode(irModule, deadCodeEliminationOptions); } eliminateMultiLevelBreak(irModule); @@ -1076,7 +1276,7 @@ Result linkAndOptimizeIR( // For now we are avoiding that problem by simply *not* emitting live-range // information when we fix variable scoping later on. - // Depending on the target, certain things that were represented as + // Depending on the target, certain things that were represented ass // single IR instructions will need to be emitted with the help of // function declaratons in output high-level code. // diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index ec4c749fa..4eefeabd5 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -307,7 +307,11 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o // First, if `inst` is an instruction that might have some effects // when it is executed, then we should keep it around. // - if (inst->mightHaveSideEffects(SideEffectAnalysisOptions::UseDominanceTree)) + SideEffectAnalysisOptions sideEffectOptions = options.useFastAnalysis + ? SideEffectAnalysisOptions::None + : SideEffectAnalysisOptions::UseDominanceTree; + + if (inst->mightHaveSideEffects(sideEffectOptions)) { return true; } diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h index d8819e042..d52ec817e 100644 --- a/source/slang/slang-ir-dce.h +++ b/source/slang/slang-ir-dce.h @@ -11,6 +11,7 @@ namespace Slang { bool keepExportsAlive = false; bool keepLayoutsAlive = false; + bool useFastAnalysis = false; }; /// Eliminate "dead" code from the given IR module. diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index a2ccc1ed7..0d5cb1c70 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -29,6 +29,8 @@ struct InliningPassBase /// The module that we are optimizing/transforming IRModule* m_module = nullptr; + HashSet<IRInst*>* m_modifiedFuncs = nullptr; + /// Initialize an inlining pass to operate on the given `module` InliningPassBase(IRModule* module) : m_module(module) @@ -157,6 +159,11 @@ struct InliningPassBase // given call site, we hand off the a worker routine // that does the meat of the work. // + if (m_modifiedFuncs) + { + if (auto parentFunc = getParentFunc(call)) + m_modifiedFuncs->add(parentFunc); + } inlineCallSite(callSite); return true; } @@ -698,12 +705,13 @@ struct MandatoryEarlyInliningPass : InliningPassBase }; -void performMandatoryEarlyInlining(IRModule* module) +bool performMandatoryEarlyInlining(IRModule* module, HashSet<IRInst*>* modifiedFuncs) { SLANG_PROFILE; MandatoryEarlyInliningPass pass(module); - pass.considerAllCallSites(); + pass.m_modifiedFuncs = modifiedFuncs; + return pass.considerAllCallSites(); } namespace { // anonymous diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index e12f5f0d5..1b37a28ae 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -2,6 +2,7 @@ #pragma once #include "../../slang-com-helper.h" +#include "core/slang-basic.h" namespace Slang { @@ -10,12 +11,13 @@ namespace Slang struct IRGlobalValueWithCode; class DiagnosticSink; class TargetProgram; + struct IRInst; /// Any call to a function that takes or returns a string/RefType parameter is inlined Result performTypeInlining(IRModule* module, DiagnosticSink* sink); /// Inline any call sites to functions marked `[unsafeForceInlineEarly]` - void performMandatoryEarlyInlining(IRModule* module); + bool performMandatoryEarlyInlining(IRModule* module, HashSet<IRInst*>* modifiedFuncs = nullptr); /// Inline any call sites to functions marked `[ForceInline]` void performForceInlining(IRModule* module); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 1729085be..8d1880366 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3681,6 +3681,13 @@ public: { return emitCallInst(type, func, args.getCount(), args.getBuffer()); } + IRCall* emitCallInst( + IRType* type, + IRInst* func, + ArrayView<IRInst*> args) + { + return emitCallInst(type, func, args.getCount(), args.getBuffer()); + } IRInst* emitTryCallInst( IRType* type, diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index f36e02066..6bde87765 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -212,7 +212,7 @@ struct LegalCallBuilder IRCall* m_call = nullptr; /// The legalized arguments for the call - List<IRInst*> m_args; + ShortList<IRInst*> m_args; /// Add a logical argument to the call (which may map to zero or mmore actual arguments) void addArg( @@ -463,7 +463,7 @@ private: resultType, m_call->getCallee(), m_args.getCount(), - m_args.getBuffer()); + m_args.getArrayView().getBuffer()); } }; @@ -706,7 +706,7 @@ static LegalVal legalizeRetVal( return LegalVal(); } -static void _addVal(List<IRInst*>& rs, const LegalVal& legalVal) +static void _addVal(ShortList<IRInst*>& rs, const LegalVal& legalVal) { switch (legalVal.flavor) { @@ -733,7 +733,7 @@ static LegalVal legalizeUnconditionalBranch( ArrayView<LegalVal> args, IRUnconditionalBranch* branchInst) { - List<IRInst*> newArgs; + ShortList<IRInst*> newArgs; for (auto arg : args) { switch (arg.flavor) @@ -757,7 +757,7 @@ static LegalVal legalizeUnconditionalBranch( SLANG_UNIMPLEMENTED_X("Unknown legalized val flavor."); } } - context->builder->emitIntrinsicInst(nullptr, branchInst->getOp(), newArgs.getCount(), newArgs.getBuffer()); + context->builder->emitIntrinsicInst(nullptr, branchInst->getOp(), newArgs.getCount(), newArgs.getArrayView().getBuffer()); return LegalVal(); } @@ -861,7 +861,7 @@ static LegalVal legalizeDebugVar(IRTypeLegalizationContext* context, LegalType t static LegalVal legalizeDebugValue(IRTypeLegalizationContext* context, LegalVal debugVar, LegalVal debugValue, IRDebugValue* originalInst) { // For now we just discard any special part and keep the ordinary part. - List<IRInst*> accessChain; + ShortList<IRInst*> accessChain; for (UInt i = 0; i < originalInst->getAccessChainCount(); i++) { accessChain.add(originalInst->getAccessChain(i)); @@ -873,7 +873,7 @@ static LegalVal legalizeDebugValue(IRTypeLegalizationContext* context, LegalVal context->builder->emitDebugValue( debugVar.getSimple(), debugValue.getSimple(), - accessChain.getArrayView())); + accessChain.getArrayView().arrayView)); case LegalType::Flavor::none: return LegalVal(); case LegalType::Flavor::pair: @@ -2205,7 +2205,7 @@ static LegalVal legalizeInst( // value of each, and collect them in an array for subsequent use. // auto argCount = inst->getOperandCount(); - List<LegalVal> legalArgs; + ShortList<LegalVal> legalArgs; // // Along the way we will also note whether there were any operands // with non-simple legalized values. @@ -2277,7 +2277,7 @@ static LegalVal legalizeInst( context, inst, legalType, - legalArgs.getArrayView()); + legalArgs.getArrayView().arrayView); if (legalVal.flavor == LegalVal::Flavor::simple) { diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index dd85487f5..198804500 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -182,7 +182,8 @@ namespace Slang void stripWrapExistential(IRModule* module) { - auto& workList = *module->getContainerPool().getList<IRInst>(); + InstWorkList workList(module); + workList.add(module->getModuleInst()); for (Index i = 0; i < workList.getCount(); i++) { @@ -229,7 +230,6 @@ namespace Slang // and used to create a tuple representing the existential value. augmentMakeExistentialInsts(module); - lowerGenericFunctions(&sharedContext); if (sink->getErrorCount() != 0) return; @@ -271,4 +271,28 @@ namespace Slang // We should remove them now. stripWrapExistential(module); } + + void cleanupGenerics(TargetProgram* program, IRModule* module, DiagnosticSink* sink) + { + SharedGenericsLoweringContext sharedContext(module); + sharedContext.targetProgram = program; + sharedContext.sink = sink; + + specializeRTTIObjects(&sharedContext, sink); + + lowerTuples(module, sink); + if (sink->getErrorCount() != 0) + return; + + generateAnyValueMarshallingFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; + + // At this point, we should no longer need to care any `WrapExistential` insts, + // although they could still exist in the IR in order to call generic stdlib functions, + // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). + // We should remove them now. + stripWrapExistential(module); + } + } // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.h b/source/slang/slang-ir-lower-generics.h index 2ec250803..238138111 100644 --- a/source/slang/slang-ir-lower-generics.h +++ b/source/slang/slang-ir-lower-generics.h @@ -16,4 +16,7 @@ namespace Slang IRModule* module,
DiagnosticSink* sink);
+ // Clean up any generic-related IR insts that are no longer needed. Called when
+ // it has been determined that no more dynamic dispatch code will be generated.
+ void cleanupGenerics(TargetProgram* targetReq, IRModule* module, DiagnosticSink* sink);
}
diff --git a/source/slang/slang-ir-lower-size-of.cpp b/source/slang/slang-ir-lower-size-of.cpp deleted file mode 100644 index 5e7e26824..000000000 --- a/source/slang/slang-ir-lower-size-of.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "slang-ir-lower-size-of.h" - -#include "slang-ir.h" -#include "slang-ir-insts.h" - -#include "slang-ir-layout.h" - -namespace Slang -{ - -struct SizeOfLikeLoweringContext -{ - void _addToWorkList(IRInst* inst) - { - if (!findOuterGeneric(inst) && !m_workList.contains(inst)) - { - m_workList.add(inst); - } - } - - void _processInst(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_AlignOf: - case kIROp_SizeOf: - _processSizeOfLike(inst); - break; - default: - break; - } - } - - void processModule() - { - _addToWorkList(m_module->getModuleInst()); - - while (m_workList.getCount() != 0) - { - IRInst* inst = m_workList.getLast(); - m_workList.removeLast(); - - _processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - _addToWorkList(child); - } - } - } - - void _processSizeOfLike(IRInst* sizeOfLikeInst) - { - auto typeOperand = as<IRType>(sizeOfLikeInst->getOperand(0)); - - IRSizeAndAlignment sizeAndAlignment; - - if (SLANG_FAILED(getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), typeOperand, &sizeAndAlignment))) - { - // Output a diagnostic failure - if(sizeOfLikeInst->getOp() == kIROp_AlignOf) - { - m_sink->diagnose(sizeOfLikeInst, Diagnostics::unableToAlignOf, typeOperand); - } - else - { - m_sink->diagnose(sizeOfLikeInst, Diagnostics::unableToSizeOf, typeOperand); - } - - return; - } - - IRBuilder builder(m_module); - - const auto value = (sizeOfLikeInst->getOp() == kIROp_AlignOf) ? - sizeAndAlignment.alignment : - sizeAndAlignment.size; - - auto valueInst = builder.getIntValue(sizeOfLikeInst->getDataType(), value); - - // Replace all uses of sizeOfLikeInst with the value - sizeOfLikeInst->replaceUsesWith(valueInst); - // We don't need the instruction any more - sizeOfLikeInst->removeAndDeallocate(); - } - - SizeOfLikeLoweringContext(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink): - m_module(module), - m_targetProgram(targetProgram), - m_sink(sink) - { - } - - TargetProgram* m_targetProgram; - DiagnosticSink* m_sink; - IRModule* m_module; - OrderedHashSet<IRInst*> m_workList; -}; - -void lowerSizeOfLike(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink) -{ - SizeOfLikeLoweringContext context(targetProgram, module, sink); - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-size-of.h b/source/slang/slang-ir-lower-size-of.h deleted file mode 100644 index 90ab40fbd..000000000 --- a/source/slang/slang-ir-lower-size-of.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef SLANG_IR_LOWER_SIZE_OF_H -#define SLANG_IR_LOWER_SIZE_OF_H - -// This defines an IR pass that lowers sizeof/alignof. - -namespace Slang -{ - -struct IRModule; -class TargetProgram; -class DiagnosticSink; - -void lowerSizeOfLike(TargetProgram* target, IRModule* module, DiagnosticSink* sink); - -} // namespace Slang - -#endif diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 16e440b32..db867fe7d 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -11,13 +11,15 @@ struct PeepholeContext : InstPassBase { PeepholeContext(IRModule* inModule) : InstPassBase(inModule) - {} + { + } bool changed = false; FloatingPointMode floatingPointMode = FloatingPointMode::Precise; bool removeOldInst = true; bool isInGeneric = false; bool isPrelinking = false; + bool useFastAnalysis = false; TargetProgram* targetProgram; @@ -251,16 +253,19 @@ struct PeepholeContext : InstPassBase switch (inst->getOp()) { case kIROp_AlignOf: - // Fold all calls to alignOf<T>() that returns a simple integer value. - if (inst->getDataType()->getOp() == kIROp_IntType) + case kIROp_SizeOf: { if (!targetProgram) break; // Save the alignment information and exit early if it is invalid IRSizeAndAlignment sizeAlignment; - auto alignOfInst = as<IRAlignOf>(inst); - auto baseType = alignOfInst->getBaseOp()->getDataType(); + IRType* baseType = nullptr; + if (auto t = as<IRType>(inst->getOperand(0))) + baseType = t; + else + baseType = inst->getOperand(0)->getDataType(); + if (SLANG_FAILED(getNaturalSizeAndAlignment(targetProgram->getOptionSet(), baseType, &sizeAlignment))) break; if (sizeAlignment.size == 0) @@ -268,8 +273,12 @@ struct PeepholeContext : InstPassBase IRBuilder builder(module); builder.setInsertBefore(inst); - auto stride = builder.getIntValue(inst->getDataType(), sizeAlignment.getStride()); - inst->replaceUsesWith(stride); + IRInst* resultVal = nullptr; + if (inst->getOp() == kIROp_AlignOf) + resultVal = builder.getIntValue(inst->getDataType(), sizeAlignment.alignment); + else + resultVal = builder.getIntValue(inst->getDataType(), sizeAlignment.size); + inst->replaceUsesWith(resultVal); maybeRemoveOldInst(inst); changed = true; } @@ -891,7 +900,7 @@ struct PeepholeContext : InstPassBase // Never remove param inst. changed = true; } - else + else if (!useFastAnalysis) { // If argValue is defined locally, // we can replace only if argVal dominates inst. @@ -1114,7 +1123,8 @@ struct PeepholeContext : InstPassBase bool processFunc(IRInst* func) { - func->getModule()->invalidateAllAnalysis(); + if (!useFastAnalysis) + func->getModule()->invalidateAllAnalysis(); bool lastIsInGeneric = isInGeneric; if (!isInGeneric) @@ -1147,6 +1157,9 @@ bool peepholeOptimize(TargetProgram* target, IRModule* module, PeepholeOptimizat PeepholeContext context = PeepholeContext(module); context.targetProgram = target; context.isPrelinking = options.isPrelinking; + context.useFastAnalysis = target + ? target->getOptionSet().getBoolOption(CompilerOptionName::MinimumSlangOptimization) + : true; return context.processModule(); } @@ -1154,6 +1167,9 @@ bool peepholeOptimize(TargetProgram* target, IRInst* func) { PeepholeContext context = PeepholeContext(func->getModule()); context.targetProgram = target; + context.useFastAnalysis = target + ? target->getOptionSet().getBoolOption(CompilerOptionName::MinimumSlangOptimization) + : true; return context.processFunc(func); } @@ -1161,7 +1177,7 @@ bool peepholeOptimizeGlobalScope(TargetProgram* target, IRModule* module) { PeepholeContext context = PeepholeContext(module); context.targetProgram = target; - + context.useFastAnalysis = true; bool result = false; for (;;) { @@ -1183,6 +1199,7 @@ bool tryReplaceInstUsesWithSimplifiedValue(TargetProgram* target, IRModule* modu PeepholeContext context = PeepholeContext(inst->getModule()); context.targetProgram = target; context.removeOldInst = false; + context.useFastAnalysis = true; context.processInst(inst); return context.changed; } diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index d5987db48..419477790 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -1189,8 +1189,11 @@ bool specializeResourceUsage( // and turned into SSA temporaries. Such optimization may enable // the following passes to "see" and specialize more cases. // - simplifyIR(codeGenContext->getTargetProgram(), irModule, - IRSimplificationOptions::getFast(codeGenContext->getTargetProgram())); + if (changed) + { + simplifyIR(codeGenContext->getTargetProgram(), irModule, + IRSimplificationOptions::getFast(codeGenContext->getTargetProgram())); + } result |= changed; } if (unspecializableFuncs.getCount() == 0) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 6137a7158..8fac31a2b 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -56,11 +56,17 @@ struct SpecializationContext SpecializationContext(IRModule* inModule, TargetProgram* target) : workList(*inModule->getContainerPool().getList<IRInst>()) , workListSet(*inModule->getContainerPool().getHashSet<IRInst>()) - , cleanInsts(*module->getContainerPool().getHashSet<IRInst>()) + , cleanInsts(*inModule->getContainerPool().getHashSet<IRInst>()) , module(inModule) , targetProgram(target) { } + ~SpecializationContext() + { + module->getContainerPool().free(&workList); + module->getContainerPool().free(&workListSet); + module->getContainerPool().free(&cleanInsts); + } // An instruction is then fully specialized if and only // if it is in our set. @@ -442,7 +448,7 @@ struct SpecializationContext // that our args are fully specialized/concrete. // UCount argCount = specInst->getArgCount(); - List<IRInst*> args; + ShortList<IRInst*> args; for (UIndex ii = 0; ii < argCount; ii++) args.add(specInst->getArg(ii)); @@ -454,7 +460,7 @@ struct SpecializationContext builder.getTypeKind(), specDiffFunc->getBase()->getDataType(), argCount, - args.getBuffer()); + args.getArrayView().getBuffer()); // Specialize the custom derivative function with the original arguments. builder.setInsertBefore(specInst); @@ -462,7 +468,7 @@ struct SpecializationContext (IRType*)newDiffFuncType, specDiffFunc->getBase(), argCount, - args.getBuffer()); + args.getArrayView().getBuffer()); // Add the new spec insts to the list so they get specialized with // the usual logic. @@ -722,6 +728,7 @@ struct SpecializationContext builder.setInsertInto(moduleInst); auto dictInst = builder.emitIntrinsicInst(nullptr, dictOp, 0, nullptr); builder.setInsertInto(dictInst); + List<IRInst*> args; for (const auto& [key, value] : dict) { if (!value->parent) @@ -731,10 +738,10 @@ struct SpecializationContext if (!keyVal->parent) goto next; } { - List<IRInst*> args; + args.clear(); args.add(value); args.addRange(key.vals); - builder.emitIntrinsicInst(nullptr, kIROp_SpecializationDictionaryItem, args.getCount(), args.getBuffer()); + builder.emitIntrinsicInst(nullptr, kIROp_SpecializationDictionaryItem, (UInt)args.getCount(), args.getBuffer()); } next:; } @@ -955,11 +962,11 @@ struct SpecializationContext IRBuilder builder(module); builder.setInsertBefore(inst); - List<IRInst*> args; + ShortList<IRInst*> args; args.add(wrapExistential->getWrappedValue()); for (UInt i = 1; i < inst->getArgCount(); i++) args.add(inst->getArg(i)); - List<IRInst*> slotOperands; + ShortList<IRInst*> slotOperands; UInt slotOperandCount = wrapExistential->getSlotOperandCount(); for (UInt ii = 0; ii < slotOperandCount; ++ii) { @@ -977,9 +984,9 @@ struct SpecializationContext innerResultType = builder.getPtrType(elementType); } auto newCallee = getNewSpecializedBufferLoadCallee(inst->getCallee(), sbType, innerResultType); - auto newCall = builder.emitCallInst(innerResultType, newCallee, args); + auto newCall = builder.emitCallInst(innerResultType, newCallee, (UInt)args.getCount(), args.getArrayView().getBuffer()); auto newWrapExistential = builder.emitWrapExistential( - resultType, newCall, slotOperandCount, slotOperands.getBuffer()); + resultType, newCall, slotOperandCount, slotOperands.getArrayView().getBuffer()); inst->replaceUsesWith(newWrapExistential); workList.remove(inst); inst->removeAndDeallocate(); @@ -1147,7 +1154,7 @@ struct SpecializationContext // We will start by constructing the argument list for the new call. // argCounter = 0; - List<IRInst*> newArgs; + ShortList<IRInst*> newArgs; for (auto param : calleeFunc->getParams()) { auto arg = inst->getArg(argCounter++); @@ -1193,7 +1200,7 @@ struct SpecializationContext builder->setInsertBefore(inst); auto newCall = builder->emitCallInst( - inst->getFullType(), specializedCallee, newArgs); + inst->getFullType(), specializedCallee, (UInt)newArgs.getCount(), newArgs.getArrayView().getBuffer()); // We will completely replace the old `call` instruction with the // new one, and will go so far as to transfer any decorations @@ -1235,6 +1242,7 @@ struct SpecializationContext } // Test if a type is compile time constant. + HashSet<IRInst*> seenTypeSet; bool isCompileTimeConstantType(IRInst* inst) { // TODO: We probably need/want a more robust test here. @@ -1243,10 +1251,11 @@ struct SpecializationContext if (!isInstFullySpecialized(inst)) return false; - List<IRInst*> localWorkList; - HashSet<IRInst*> processedInsts; + ShortList<IRInst*> localWorkList; + seenTypeSet.clear(); + localWorkList.add(inst); - processedInsts.add(inst); + seenTypeSet.add(inst); while (localWorkList.getCount() != 0) { @@ -1269,10 +1278,8 @@ struct SpecializationContext for (UInt i = 0; i < curInst->getOperandCount(); ++i) { auto operand = curInst->getOperand(i); - if (processedInsts.add(operand)) - { + if (seenTypeSet.add(operand)) localWorkList.add(operand); - } } } return true; @@ -1375,7 +1382,7 @@ struct SpecializationContext // the lists here because we don't yet have a basic // block, or even a function, to insert them into. // - List<IRParam*> newParams; + ShortList<IRParam*, 16> newParams; UInt argCounter = 0; for (auto oldParam : oldFunc->getParams()) { @@ -1481,14 +1488,14 @@ struct SpecializationContext // In order to construct the type of the new function, we // need to extract the types of all its parameters. // - List<IRType*> newParamTypes; + ShortList<IRType*> newParamTypes; for (auto newParam : newParams) { newParamTypes.add(newParam->getFullType()); } IRType* newFuncType = builder->getFuncType( newParamTypes.getCount(), - newParamTypes.getBuffer(), + newParamTypes.getArrayView().getBuffer(), oldFunc->getResultType()); newFunc->setFullType(newFuncType); @@ -1690,7 +1697,7 @@ struct SpecializationContext return false; - List<IRInst*> slotOperands; + ShortList<IRInst*> slotOperands; UInt slotOperandCount = wrapInst->getSlotOperandCount(); for (UInt ii = 0; ii < slotOperandCount; ++ii) { @@ -1702,7 +1709,7 @@ struct SpecializationContext resultType, newLoadInst, slotOperandCount, - slotOperands.getBuffer()); + slotOperands.getArrayView().getBuffer()); addUsersToWorkList(inst); @@ -1823,7 +1830,7 @@ struct SpecializationContext auto foundFieldType = foundField->getFieldType(); - List<IRInst*> slotOperands; + ShortList<IRInst*> slotOperands; UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); for (UInt ii = 0; ii < slotOperandCount; ++ii) @@ -1840,7 +1847,7 @@ struct SpecializationContext resultType, newGetField, slotOperandCount, - slotOperands.getBuffer()); + slotOperands.getArrayView().getBuffer()); addUsersToWorkList(inst); inst->replaceUsesWith(newWrapExistentialInst); @@ -1913,7 +1920,7 @@ struct SpecializationContext auto foundFieldType = foundField->getFieldType(); - List<IRInst*> slotOperands; + ShortList<IRInst*> slotOperands; UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); for (UInt ii = 0; ii < slotOperandCount; ++ii) @@ -1930,7 +1937,7 @@ struct SpecializationContext resultType, newGetFieldAddr, slotOperandCount, - slotOperands.getBuffer()); + slotOperands.getArrayView().getBuffer()); addUsersToWorkList(inst); inst->replaceUsesWith(newWrapExistentialInst); @@ -1958,7 +1965,7 @@ struct SpecializationContext auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType(); - List<IRInst*> slotOperands; + ShortList<IRInst*> slotOperands; UInt slotOperandCount = wrapInst->getSlotOperandCount(); for (UInt ii = 0; ii < slotOperandCount; ++ii) @@ -1969,7 +1976,7 @@ struct SpecializationContext auto newGetElement = builder.emitElementExtract(elementType, val, index); auto newWrapExistentialInst = builder.emitWrapExistential( - resultType, newGetElement, slotOperandCount, slotOperands.getBuffer()); + resultType, newGetElement, slotOperandCount, slotOperands.getArrayView().getBuffer()); addUsersToWorkList(inst); inst->replaceUsesWith(newWrapExistentialInst); @@ -1999,7 +2006,7 @@ struct SpecializationContext IRBuilder builder(module); builder.setInsertBefore(inst); - List<IRInst*> slotOperands; + ShortList<IRInst*> slotOperands; UInt slotOperandCount = wrapInst->getSlotOperandCount(); for (UInt ii = 0; ii < slotOperandCount; ++ii) @@ -2011,7 +2018,7 @@ struct SpecializationContext auto newElementAddr = builder.emitElementAddress(elementPtrType, val, index); auto newWrapExistentialInst = builder.emitWrapExistential( - resultType, newElementAddr, slotOperandCount, slotOperands.getBuffer()); + resultType, newElementAddr, slotOperandCount, slotOperands.getArrayView().getBuffer()); addUsersToWorkList(inst); inst->replaceUsesWith(newWrapExistentialInst); diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index a0224cea5..6c02734b5 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -45,6 +45,8 @@ namespace Slang const int kMaxIterations = 8; const int kMaxFuncIterations = 16; int iterationCounter = 0; + IRDeadCodeEliminationOptions dceOptions = IRDeadCodeEliminationOptions(); + dceOptions.useFastAnalysis = options.minimalOptimization; while (changed && iterationCounter < kMaxIterations) { @@ -74,20 +76,19 @@ namespace Slang if (!options.minimalOptimization) funcChanged |= removeRedundancyInFunc(func); funcChanged |= simplifyCFG(func, options.cfgOptions); - eliminateDeadCode(func); - funcChanged |= constructSSA(func); + // Note: we disregard the `changed` state from dead code elimination pass since + // SCCP pass could be generating temporarily evaluated constant values and never actually use them. + // DCE will always remove those nearly generated consts and always returns true here. + eliminateDeadCode(func, dceOptions); + if (funcIterationCount == 0) + funcChanged |= constructSSA(func); changed |= funcChanged; funcIterationCount++; } } - - // Note: we disregard the `changed` state from dead code elimination pass since - // SCCP pass could be generating temporarily evaluated constant values and never actually use them. - // DCE will always remove those nearly generated consts and always returns true here. - eliminateDeadCode(module); - iterationCounter++; } + eliminateDeadCode(module, dceOptions); } void simplifyNonSSAIR(TargetProgram* target, IRModule* module, IRSimplificationOptions options) @@ -95,6 +96,9 @@ namespace Slang bool changed = true; const int kMaxIterations = 8; int iterationCounter = 0; + IRDeadCodeEliminationOptions dceOptions = IRDeadCodeEliminationOptions(); + dceOptions.useFastAnalysis = options.minimalOptimization; + while (changed && iterationCounter < kMaxIterations) { changed = false; @@ -107,7 +111,7 @@ namespace Slang // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never actually use them. // DCE will always remove those nearly generated consts and always returns true here. - eliminateDeadCode(module); + eliminateDeadCode(module, dceOptions); iterationCounter++; } } @@ -115,6 +119,9 @@ namespace Slang void simplifyFunc(TargetProgram* target, IRGlobalValueWithCode* func, IRSimplificationOptions options, DiagnosticSink* sink) { + IRDeadCodeEliminationOptions dceOptions = IRDeadCodeEliminationOptions(); + dceOptions.useFastAnalysis = options.minimalOptimization; + bool changed = true; const int kMaxIterations = 8; int iterationCounter = 0; @@ -133,7 +140,7 @@ namespace Slang // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never actually use them. // DCE will always remove those nearly generated consts and always returns true here. - eliminateDeadCode(func); + eliminateDeadCode(func, dceOptions); changed |= constructSSA(func); diff --git a/source/slang/slang-ir-uniformity.cpp b/source/slang/slang-ir-uniformity.cpp index e09b6d89c..6f81e85f0 100644 --- a/source/slang/slang-ir-uniformity.cpp +++ b/source/slang/slang-ir-uniformity.cpp @@ -191,8 +191,9 @@ namespace Slang void propagateNonUniform(IRFunc* root, List<IRInst*>& workList) { - List<IRInst*>& nextWorkList = *module->getContainerPool().getList<IRInst>(); - HashSet<IRInst*>& workListSet = *module->getContainerPool().getHashSet<IRInst>(); + InstWorkList nextWorkList(module); + InstHashSet workListSet(module); + auto addToWorkList = [&](IRInst* inst) { if (workListSet.add(inst)) @@ -404,14 +405,15 @@ namespace Slang addToWorkList(user); } } - workList.swapWith(nextWorkList); + workList.swapWith(nextWorkList.getList()); nextWorkList.clear(); } } void analyzeModule() { - List<IRInst*>& workList = *module->getContainerPool().getList<IRInst>(); + InstWorkList workList(module); + for (auto globalInst : module->getGlobalInsts()) { if (auto code = as<IRGlobalValueWithCode>(globalInst)) @@ -438,7 +440,7 @@ namespace Slang } currentCallee = func; call = nullptr; - propagateNonUniform(func, workList); + propagateNonUniform(func, workList.getList()); } } workList.clear(); @@ -448,7 +450,7 @@ namespace Slang void eliminateAsDynamicUniformInst() { - List<IRInst*>& workList = *module->getContainerPool().getList<IRInst>(); + InstWorkList workList(module); workList.add(module->getModuleInst()); for (Index i = 0; i < workList.getCount(); i++) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 51b8344f6..967db3aff 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1077,11 +1077,6 @@ bool isPureFunctionalCall(IRCall* call, SideEffectAnalysisOptions options) bool isSideEffectFreeFunctionalCall(IRCall* call, SideEffectAnalysisOptions options) { - // If the call has been marked as no-side-effect, we - // will treat it so, by-passing all other checks. - if (call->findDecoration<IRNoSideEffectDecoration>()) - return false; - if (!doesCalleeHaveSideEffect(call->getCallee())) { return areCallArgumentsSideEffectFree(call, options); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index d3a4c8026..3a5384e8d 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2419,6 +2419,7 @@ struct InstWorkList other.pool = nullptr; return *this; } + List<IRInst*>& getList() { return *workList; } IRInst* operator[](Index i) { return (*workList)[i]; } Index getCount() { return workList->getCount(); } IRInst** begin() { return workList->begin(); } @@ -2464,7 +2465,7 @@ struct InstHashSet other.pool = nullptr; return *this; } - + HashSet<IRInst*>& getHashSet() { return *set; } Index getCount() { return set->getCount(); } bool add(IRInst* inst) { return set->add(inst); } bool contains(IRInst* inst) { return set->contains(inst); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f92ecc92a..8a8d235f5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10708,6 +10708,7 @@ RefPtr<IRModule> generateIRForTranslationUnit( auto session = translationUnit->getSession(); auto compileRequest = translationUnit->compileRequest; + Linkage* linkage = compileRequest->getLinkage(); SharedIRGenContext sharedContextStorage( session, @@ -10823,36 +10824,36 @@ RefPtr<IRModule> generateIRForTranslationUnit( // Generate DebugValue insts to store values into debug variables, // if debug symbols are enabled. - if (compileRequest->getLinkage()->m_optionSet.getEnumOption<DebugInfoLevel>(CompilerOptionName::DebugInformation) != DebugInfoLevel::None) + if (context->includeDebugInfo) { insertDebugValueStore(module); } + // Next, attempt to promote local variables to SSA // temporaries and do basic simplifications. // constructSSA(module); - simplifyCFG(module, CFGSimplificationOptions::getDefault()); applySparseConditionalConstantPropagation(module, compileRequest->getSink()); - peepholeOptimize(nullptr, module, PeepholeOptimizationOptions::getPrelinking()); + + bool minimumOptimizations = linkage->m_optionSet.getBoolOption(CompilerOptionName::MinimumSlangOptimization); + if (!minimumOptimizations) + { + simplifyCFG(module, CFGSimplificationOptions::getDefault()); + auto peepholeOptions = PeepholeOptimizationOptions::getPrelinking(); + peepholeOptimize(nullptr, module, peepholeOptions); + } + + IRDeadCodeEliminationOptions dceOptions = IRDeadCodeEliminationOptions(); + dceOptions.keepExportsAlive = true; + dceOptions.keepLayoutsAlive = true; + dceOptions.useFastAnalysis = true; for (auto inst : module->getGlobalInsts()) { if (auto func = as<IRGlobalValueWithCode>(inst)) - eliminateDeadCode(func); + eliminateDeadCode(func, dceOptions); } - // Next, inline calls to any functions that have been - // marked for mandatory "early" inlining. - // - // Note: We performed certain critical simplifications - // above, before this step, so that the body of functions - // subject to mandatory inlining can be simplified ahead - // of time. By simplifying the body before inlining it, - // we can make sure that things like superfluous temporaries - // are eliminated from the callee, and not copied into - // call sites. - // - performMandatoryEarlyInlining(module); // Where possible, move loop condition checks to the end of loops, and wrap // the loop in an 'if(condition)'. @@ -10875,43 +10876,63 @@ RefPtr<IRModule> generateIRForTranslationUnit( invertLoops(module); } - // Next, attempt to promote local variables to SSA - // temporaries and do basic simplifications. + // Next, inline calls to any functions that have been + // marked for mandatory "early" inlining. + // + // Note: We performed certain critical simplifications + // above, before this step, so that the body of functions + // subject to mandatory inlining can be simplified ahead + // of time. By simplifying the body before inlining it, + // we can make sure that things like superfluous temporaries + // are eliminated from the callee, and not copied into + // call sites. // + InstHashSet modifiedFuncs(module); for (;;) { bool changed = false; - performMandatoryEarlyInlining(module); - changed |= constructSSA(module); - simplifyCFG(module, CFGSimplificationOptions::getDefault()); - changed |= applySparseConditionalConstantPropagation(module, compileRequest->getSink()); - changed |= peepholeOptimize(nullptr, module, PeepholeOptimizationOptions::getPrelinking()); - for (auto inst : module->getGlobalInsts()) + modifiedFuncs.clear(); + changed = performMandatoryEarlyInlining(module, &modifiedFuncs.getHashSet()); + if (changed) { - if (auto func = as<IRGlobalValueWithCode>(inst)) - eliminateDeadCode(func); + changed = peepholeOptimizeGlobalScope(nullptr, module); + if (!minimumOptimizations) + { + for (auto func : modifiedFuncs.getHashSet()) + { + auto codeInst = as<IRGlobalValueWithCode>(func); + changed |= constructSSA(func); + changed |= applySparseConditionalConstantPropagation(func, compileRequest->getSink()); + changed |= peepholeOptimize(nullptr, func); + changed |= simplifyCFG(codeInst, CFGSimplificationOptions::getFast()); + eliminateDeadCode(func, dceOptions); + } + } } if (!changed) break; } - // Propagate `constexpr`-ness through the dataflow graph (and the - // call graph) based on constraints imposed by different instructions. - propagateConstExpr(module, compileRequest->getSink()); - - // TODO: give error messages if any `undefined` or - // `unreachable` instructions remain. - // Check for using uninitialized out parameters. - checkForUsingUninitializedOutParams(module, compileRequest->getSink()); + if (compileRequest->getLinkage()->m_optionSet.shouldRunNonEssentialValidation()) + { + // Propagate `constexpr`-ness through the dataflow graph (and the + // call graph) based on constraints imposed by different instructions. + propagateConstExpr(module, compileRequest->getSink()); - checkForMissingReturns(module, compileRequest->getSink()); + checkForUsingUninitializedOutParams(module, compileRequest->getSink()); + + // TODO: give error messages if any `undefined` or + // instructions remain. + + checkForMissingReturns(module, compileRequest->getSink()); - // We don't allow recursive types. - checkForRecursiveTypes(module, compileRequest->getSink()); + // We don't allow recursive types. + checkForRecursiveTypes(module, compileRequest->getSink()); - // Check for invalid differentiable function body. - checkAutoDiffUsages(module, compileRequest->getSink()); + // Check for invalid differentiable function body. + checkAutoDiffUsages(module, compileRequest->getSink()); + } // The "mandatory" optimization passes may make use of the // `IRHighLevelDeclDecoration` type to relate IR instructions @@ -10937,8 +10958,6 @@ RefPtr<IRModule> generateIRForTranslationUnit( // IRStripOptions stripOptions; - Linkage* linkage = compileRequest->getLinkage(); - stripOptions.shouldStripNameHints = linkage->m_optionSet.shouldObfuscateCode(); // If we are generating an obfuscated source map, we don't want to strip locs, @@ -10959,15 +10978,14 @@ RefPtr<IRModule> generateIRForTranslationUnit( // pass here, but make sure to set our options so that we don't // eliminate anything that has been marked for export. // - IRDeadCodeEliminationOptions options; - options.keepExportsAlive = true; - eliminateDeadCode(module, options); + eliminateDeadCode(module, dceOptions); - if (linkage->m_optionSet.shouldObfuscateCode()) - { - // The obfuscated source map is stored on the module - obfuscateModuleLocs(module, compileRequest->getSourceManager()); - } + } + + if (linkage->m_optionSet.shouldObfuscateCode()) + { + // The obfuscated source map is stored on the module + obfuscateModuleLocs(module, compileRequest->getSourceManager()); } // TODO: consider doing some more aggressive optimizations diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index d140df4c3..2d3f35557 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -291,6 +291,7 @@ void initCommandOptions(CommandOptions& options) { OptionKind::MatrixLayoutRow,"-matrix-layout-row-major", nullptr, "Set the default matrix layout to row-major."}, { OptionKind::IgnoreCapabilities,"-ignore-capabilities", nullptr, "Do not warn or error if capabilities are violated"}, { OptionKind::MinimumSlangOptimization, "-minimum-slang-optimization", nullptr, "Perform minimum code optimization in Slang to favor compilation time."}, + { OptionKind::DisableNonEssentialValidations, "-disable-non-essential-validations", nullptr, "Disable non-essential IR validations such as use of uninitialized variables."}, { OptionKind::ModuleName, "-module-name", "-module-name <name>", "Set the module name to use when compiling multiple .slang source files into a single module."}, { OptionKind::Output, "-o", "-o <path>", @@ -1683,6 +1684,8 @@ SlangResult OptionsParser::_parse( case OptionKind::VulkanUseGLLayout: case OptionKind::VulkanEmitReflection: case OptionKind::IgnoreCapabilities: + case OptionKind::MinimumSlangOptimization: + case OptionKind::DisableNonEssentialValidations: case OptionKind::DefaultImageFormatUnknown: case OptionKind::Obfuscate: case OptionKind::OutputIncludes: |
