summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-05-29 11:14:22 -0700
committerGitHub <noreply@github.com>2024-05-29 11:14:22 -0700
commit83f176ba8a3bae5533470aed6a90663653f894b8 (patch)
tree3e39a674cb4662c946598526f633302f139e14ab /source/slang/slang-emit.cpp
parentc1e34c5a29d99d8a70b4e78313bfd3d539d9206e (diff)
Add options to speedup compilation. (#4240)
* Add options to speedup compilation. * Fix. * Plumb options to DCE pass. * Revert debug change. * Fix regressions. * More optimizations. * more cleanup and fixes. * remove comment. * Fixes. * Another fix. * Fix errors. * Fix errors. * Add comments.
Diffstat (limited to 'source/slang/slang-emit.cpp')
-rw-r--r--source/slang/slang-emit.cpp288
1 files changed, 244 insertions, 44 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 531084434..428c9bf66 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -48,7 +48,6 @@
#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-lower-combined-texture-sampler.h"
#include "slang-ir-lower-l-value-cast.h"
-#include "slang-ir-lower-size-of.h"
#include "slang-ir-lower-reinterpret.h"
#include "slang-ir-loop-unroll.h"
#include "slang-ir-legalize-vector-types.h"
@@ -214,6 +213,155 @@ struct LinkingAndOptimizationOptions
CLikeSourceEmitter* sourceEmitter = nullptr;
};
+// To improve the performance of our backend, we will try to avoid running
+// passes related to features not used in the user code.
+// To do so, we will scan the IR module once, and determine which passes are needed
+// based on the instructions used in the IR module.
+// This will allow us to skip running passes that are not needed, without having to
+// run all the passes only to find out that no work is needed.
+// This is especially important for the performance of the backend, as some passes
+// have an initialization cost (such as building reference graphs or DOM trees) that
+// can be expensive.
+//
+struct RequiredLoweringPassSet
+{
+ bool resultType;
+ bool optionalType;
+ bool combinedTextureSamplers;
+ bool reinterpret;
+ bool generics;
+ bool bindExistential;
+ bool autodiff;
+ bool derivativePyBindWrapper;
+ bool bitcast;
+ bool existentialTypeLayout;
+ bool bindingQuery;
+ bool meshOutput;
+ bool higherOrderFunc;
+ bool glslGlobalVar;
+ bool glslSSBO;
+ bool byteAddressBuffer;
+};
+
+// Scan the IR module and determine which lowering/legalization passes are needed based
+// on the instructions we see.
+//
+void calcRequiredLoweringPassSet(RequiredLoweringPassSet& result, CodeGenContext* codeGenContext, IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_ResultType:
+ result.resultType = true;
+ break;
+ case kIROp_OptionalType:
+ result.optionalType = true;
+ break;
+ case kIROp_TextureType:
+ if (!isKhronosTarget(codeGenContext->getTargetReq()))
+ {
+ if (auto texType = as<IRTextureType>(inst))
+ {
+ auto isCombined = texType->getIsCombinedInst();
+ if (auto isCombinedVal = as<IRIntLit>(isCombined))
+ {
+ if (isCombinedVal->getValue() != 0)
+ {
+ result.combinedTextureSamplers = true;
+ }
+ }
+ else
+ {
+ result.combinedTextureSamplers = true;
+ }
+ }
+ }
+ break;
+ case kIROp_PseudoPtrType:
+ case kIROp_BoundInterfaceType:
+ case kIROp_BindExistentialsType:
+ result.generics = true;
+ result.existentialTypeLayout = true;
+ break;
+ case kIROp_GetRegisterIndex:
+ case kIROp_GetRegisterSpace:
+ result.bindingQuery = true;
+ break;
+ case kIROp_BackwardDifferentiate:
+ case kIROp_ForwardDifferentiate:
+ case kIROp_MakeDifferentialPairUserCode:
+ result.autodiff = true;
+ break;
+ case kIROp_VerticesType:
+ case kIROp_IndicesType:
+ case kIROp_PrimitivesType:
+ result.meshOutput = true;
+ break;
+ case kIROp_CreateExistentialObject:
+ case kIROp_MakeExistential:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_WrapExistential:
+ case kIROp_LookupWitness:
+ result.generics = true;
+ break;
+ case kIROp_Specialize:
+ {
+ auto specInst = as<IRSpecialize>(inst);
+ if (!findAnyTargetIntrinsicDecoration(getResolvedInstForDecorations(specInst)))
+ result.generics = true;
+ }
+ break;
+ case kIROp_Reinterpret:
+ result.reinterpret = true;
+ break;
+ case kIROp_BitCast:
+ result.bitcast = true;
+ break;
+ case kIROp_AutoPyBindCudaDecoration:
+ result.derivativePyBindWrapper = true;
+ break;
+ case kIROp_Param:
+ if (as<IRFuncType>(inst->getDataType()))
+ result.higherOrderFunc = true;
+ break;
+ case kIROp_GlobalInputDecoration:
+ case kIROp_GlobalOutputDecoration:
+ case kIROp_GetWorkGroupSize:
+ result.glslGlobalVar = true;
+ break;
+ case kIROp_BindExistentialSlotsDecoration:
+ result.bindExistential = true;
+ result.generics = true;
+ result.existentialTypeLayout = true;
+ break;
+ case kIROp_GLSLShaderStorageBufferType:
+ result.glslSSBO = true;
+ break;
+ case kIROp_ByteAddressBufferLoad:
+ case kIROp_ByteAddressBufferStore:
+ case kIROp_HLSLRWByteAddressBufferType:
+ case kIROp_HLSLByteAddressBufferType:
+ result.byteAddressBuffer = true;
+ break;
+ }
+ if (!result.generics || !result.existentialTypeLayout)
+ {
+ // If any instruction has an interface type, we need to run
+ // the generics lowering pass.
+ auto type = inst->getDataType();
+ if (type && type->getOp() == kIROp_InterfaceType)
+ {
+ result.generics = true;
+ result.existentialTypeLayout = true;
+ }
+ }
+ for (auto child : inst->getDecorationsAndChildren())
+ {
+ calcRequiredLoweringPassSet(result, codeGenContext, child);
+ }
+}
+
Result linkAndOptimizeIR(
CodeGenContext* codeGenContext,
LinkingAndOptimizationOptions const& options,
@@ -252,10 +400,14 @@ Result linkAndOptimizeIR(
// un-specialized IR.
dumpIRIfEnabled(codeGenContext, irModule, "POST IR VALIDATION");
- if(!isKhronosTarget(targetRequest))
+ // Scan the IR module and determine which lowering/legalization passes are needed.
+ RequiredLoweringPassSet requiredLoweringPassSet = {};
+ calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst());
+
+ if(!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO)
lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink);
- if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations())
+ if (requiredLoweringPassSet.glslGlobalVar)
translateGLSLGlobalVar(codeGenContext, irModule);
// Replace any global constants with their values.
@@ -274,7 +426,8 @@ Result linkAndOptimizeIR(
// shader parameters for those slots, to be wired up to
// use sites.
//
- bindExistentialSlots(irModule, sink);
+ if (requiredLoweringPassSet.bindExistential)
+ bindExistentialSlots(irModule, sink);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "EXISTENTIALS BOUND");
#endif
@@ -354,7 +507,8 @@ Result linkAndOptimizeIR(
break;
}
- lowerOptionalType(irModule, sink);
+ if (requiredLoweringPassSet.optionalType)
+ lowerOptionalType(irModule, sink);
switch (target)
{
@@ -370,7 +524,8 @@ Result linkAndOptimizeIR(
}
// Lower `Result<T,E>` types into ordinary struct types.
- lowerResultType(irModule, sink);
+ if (requiredLoweringPassSet.resultType)
+ lowerResultType(irModule, sink);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED");
@@ -382,7 +537,9 @@ Result linkAndOptimizeIR(
IRSimplificationOptions defaultIRSimplificationOptions = IRSimplificationOptions::getDefault(targetProgram);
IRSimplificationOptions fastIRSimplificationOptions = IRSimplificationOptions::getFast(targetProgram);
+ IRDeadCodeEliminationOptions deadCodeEliminationOptions = IRDeadCodeEliminationOptions();
fastIRSimplificationOptions.minimalOptimization = defaultIRSimplificationOptions.minimalOptimization;
+ deadCodeEliminationOptions.useFastAnalysis = fastIRSimplificationOptions.minimalOptimization;
simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink);
@@ -403,7 +560,8 @@ Result linkAndOptimizeIR(
fuseCallsToSaturatedCooperation(irModule);
// Generate any requested derivative wrappers
- generateDerivativeWrappers(irModule, sink);
+ if (requiredLoweringPassSet.derivativePyBindWrapper)
+ generateDerivativeWrappers(irModule, sink);
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
@@ -436,8 +594,11 @@ Result linkAndOptimizeIR(
return SLANG_FAIL;
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
- applySparseConditionalConstantPropagation(irModule, codeGenContext->getSink());
- eliminateDeadCode(irModule);
+ if (changed)
+ {
+ applySparseConditionalConstantPropagation(irModule, codeGenContext->getSink());
+ }
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
validateIRModuleIfEnabled(codeGenContext, irModule);
@@ -446,10 +607,13 @@ Result linkAndOptimizeIR(
performMandatoryEarlyInlining(irModule);
// Unroll loops.
- if (codeGenContext->getSink()->getErrorCount() == 0)
+ if (!fastIRSimplificationOptions.minimalOptimization)
{
- if (!unrollLoopsInModule(targetProgram, irModule, codeGenContext->getSink()))
- return SLANG_FAIL;
+ if (codeGenContext->getSink()->getErrorCount() == 0)
+ {
+ if (!unrollLoopsInModule(targetProgram, irModule, codeGenContext->getSink()))
+ return SLANG_FAIL;
+ }
}
// Few of our targets support higher order functions, and
@@ -457,23 +621,30 @@ Result linkAndOptimizeIR(
// which do.
// Specialize away these parameters
// TODO: We should implement a proper defunctionalization pass
- if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations())
+ if (requiredLoweringPassSet.higherOrderFunc)
changed |= specializeHigherOrderParameters(codeGenContext, irModule);
- dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
- enableIRValidationAtInsert();
- changed |= processAutodiffCalls(targetProgram, irModule, sink);
- disableIRValidationAtInsert();
- dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
+ if (requiredLoweringPassSet.autodiff)
+ {
+ dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
+ enableIRValidationAtInsert();
+ changed |= processAutodiffCalls(targetProgram, irModule, sink);
+ disableIRValidationAtInsert();
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
+ }
if (!changed)
break;
}
- finalizeAutoDiffPass(targetProgram, irModule);
+ if (requiredLoweringPassSet.autodiff)
+ finalizeAutoDiffPass(targetProgram, irModule);
finalizeSpecialization(irModule);
+ requiredLoweringPassSet = {};
+ calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst());
+
switch (target)
{
case CodeGenTarget::PyTorchCppBinding:
@@ -491,7 +662,7 @@ Result linkAndOptimizeIR(
break;
}
- if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations())
+ if (targetProgram->getOptionSet().shouldRunNonEssentialValidation())
checkForRecursiveTypes(irModule, sink);
if (sink->getErrorCount() != 0)
@@ -507,7 +678,9 @@ Result linkAndOptimizeIR(
SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink));
}
- lowerReinterpret(targetProgram, irModule, sink);
+ if (requiredLoweringPassSet.reinterpret)
+ lowerReinterpret(targetProgram, irModule, sink);
+
if (sink->getErrorCount() != 0)
return SLANG_FAIL;
@@ -517,10 +690,18 @@ Result linkAndOptimizeIR(
// but are not used for dynamic dispatch, unpin them so we don't
// do unnecessary work to lower them.
unpinWitnessTables(irModule);
+
+ if (fastIRSimplificationOptions.minimalOptimization)
+ {
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
+ }
+ else
+ {
+ simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink);
+ }
- simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink);
-
- if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc))
+ if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc) &&
+ targetProgram->getOptionSet().shouldRunNonEssentialValidation())
{
// We could fail because (perhaps, somehow) end up with getStringHash that the operand is not a string literal
SLANG_RETURN_ON_FAIL(checkGetStringHashInsts(irModule, sink));
@@ -530,7 +711,10 @@ Result linkAndOptimizeIR(
// generics / interface types to ordinary functions and types using
// function pointers.
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-LOWER-GENERICS");
- lowerGenerics(targetProgram, irModule, sink);
+ if (requiredLoweringPassSet.generics)
+ lowerGenerics(targetProgram, irModule, sink);
+ else
+ cleanupGenerics(targetProgram, irModule, sink);
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS");
if (sink->getErrorCount() != 0)
@@ -547,7 +731,14 @@ Result linkAndOptimizeIR(
// up downstream passes like type legalization, so we
// will run a DCE pass to clean up after the specialization.
//
- simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink);
+ if (fastIRSimplificationOptions.minimalOptimization)
+ {
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
+ }
+ else
+ {
+ simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink);
+ }
validateIRModuleIfEnabled(codeGenContext, irModule);
@@ -569,11 +760,15 @@ Result linkAndOptimizeIR(
case CodeGenTarget::Metal:
case CodeGenTarget::MetalLib:
case CodeGenTarget::MetalLibAssembly:
- lowerCombinedTextureSamplers(irModule, sink);
+ if (requiredLoweringPassSet.combinedTextureSamplers)
+ lowerCombinedTextureSamplers(irModule, sink);
break;
}
- addUserTypeHintDecorations(irModule);
+ if (codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::VulkanEmitReflection))
+ {
+ addUserTypeHintDecorations(irModule);
+ }
// We don't need the legalize pass for C/C++ based types
if(options.shouldLegalizeExistentialAndResourceTypes )
@@ -603,10 +798,13 @@ Result linkAndOptimizeIR(
// we need to replace it with just an `X`, after which we
// will have (more) legal shader code.
//
- legalizeExistentialTypeLayout(
- targetProgram,
- irModule,
- sink);
+ if (requiredLoweringPassSet.existentialTypeLayout)
+ {
+ legalizeExistentialTypeLayout(
+ targetProgram,
+ irModule,
+ sink);
+ }
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "EXISTENTIALS LEGALIZED");
@@ -652,7 +850,10 @@ Result linkAndOptimizeIR(
// to see if we can clean up any temporaries created by legalization.
// (e.g., things that used to be aggregated might now be split up,
// so that we can work with the individual fields).
- simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink);
+ if (fastIRSimplificationOptions.minimalOptimization)
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
+ else
+ simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA");
@@ -678,7 +879,6 @@ Result linkAndOptimizeIR(
{
specializeArrayParameters(codeGenContext, irModule);
}
- eliminateDeadCode(irModule);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION");
@@ -713,6 +913,7 @@ Result linkAndOptimizeIR(
// of aggregate types from/to byte-address buffers into
// stores of individual scalar or vector values.
//
+ if (requiredLoweringPassSet.byteAddressBuffer)
{
ByteAddressBufferLegalizationOptions byteAddressBufferOptions;
@@ -948,7 +1149,7 @@ Result linkAndOptimizeIR(
//
// We run DCE pass again to clean things up.
//
- eliminateDeadCode(irModule);
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
if (isKhronosTarget(targetRequest))
{
@@ -965,14 +1166,16 @@ Result linkAndOptimizeIR(
// Lower the `getRegisterIndex` and `getRegisterSpace` intrinsics.
//
- lowerBindingQueries(irModule, sink);
+ if (requiredLoweringPassSet.bindingQuery)
+ lowerBindingQueries(irModule, sink);
// For some small improvement in type safety we represent these as opaque
// structs instead of regular arrays.
//
// If any have survived this far, change them back to regular (decorated)
// arrays that the emitters can deal with.
- legalizeMeshOutputTypes(irModule);
+ if (requiredLoweringPassSet.meshOutput)
+ legalizeMeshOutputTypes(irModule);
if (options.shouldLegalizeExistentialAndResourceTypes)
{
@@ -997,20 +1200,17 @@ Result linkAndOptimizeIR(
rcpWOfPositionInput(irModule);
}
- // Lower sizeof/alignof
-
- lowerSizeOfLike(targetProgram, irModule, sink);
-
// Lower all bit_cast operations on complex types into leaf-level
// bit_cast on basic types.
- lowerBitCast(targetProgram, irModule, sink);
+ if (requiredLoweringPassSet.bitcast)
+ lowerBitCast(targetProgram, irModule, sink);
bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly();
if (emitSpirvDirectly)
{
performIntrinsicFunctionInlining(irModule);
- eliminateDeadCode(irModule);
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
}
eliminateMultiLevelBreak(irModule);
@@ -1076,7 +1276,7 @@ Result linkAndOptimizeIR(
// For now we are avoiding that problem by simply *not* emitting live-range
// information when we fix variable scoping later on.
- // Depending on the target, certain things that were represented as
+ // Depending on the target, certain things that were represented ass
// single IR instructions will need to be emitted with the help of
// function declaratons in output high-level code.
//