diff options
| author | Yong He <yonghe@outlook.com> | 2025-02-23 10:31:05 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-23 10:31:05 -0800 |
| commit | 51ad07d1fbffd41c758eba172aa77ebba3204924 (patch) | |
| tree | fadd788714c4ad37830846b0274d56b5ae1eff56 /source/slang/slang-ir-link.cpp | |
| parent | 0101e5ab59a1678ed7212913c3880edfaf039537 (diff) | |
Improve performance when compiling small shaders. (#6396)
Improve performance when compiling small shaders.
Avoid copying witness table entries that are not getting used during linking.
Avoid copying auto-diff-related decorations and derivative functions during linking if the user modules don't use autodiff.
Cache operator overload resolution results on the global session, so each new Session doesn't need to repeatedly run through overload resolution from scratch.
Diffstat (limited to 'source/slang/slang-ir-link.cpp')
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 390 |
1 files changed, 340 insertions, 50 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 53bbbda9e..125ab71d0 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -52,15 +52,28 @@ struct IRSharedSpecContext // A map from mangled symbol names to zero or // more global IR values that have that name, // in the *original* module. - typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary; + typedef Dictionary<ImmutableHashedString, RefPtr<IRSpecSymbol>> SymbolDictionary; SymbolDictionary symbols; + Dictionary<ImmutableHashedString, bool> isImportedSymbol; + + bool useAutodiff = false; + IRBuilder builderStorage; // The "global" specialization environment. IRSpecEnv globalEnv; }; +void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv); + +struct WitnessTableCloneInfo : RefObject +{ + IRWitnessTable* clonedTable; + IRWitnessTable* originalTable; + Dictionary<UnownedStringSlice, IRWitnessTableEntry*> deferredEntries; +}; + struct IRSpecContextBase { IRSharedSpecContext* shared; @@ -69,7 +82,27 @@ struct IRSpecContextBase IRModule* getModule() { return getShared()->module; } - IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } + List<IRModule*> irModules; + + HashSet<UnownedStringSlice> deferredWitnessTableEntryKeys; + List<RefPtr<WitnessTableCloneInfo>> witnessTables; + + IRSpecSymbol* findSymbols(UnownedStringSlice mangledName) + { + ImmutableHashedString hashedName(mangledName); + RefPtr<IRSpecSymbol> symbol; + if (shared->symbols.tryGetValue(hashedName, symbol)) + return symbol; + for (auto m : irModules) + { + for (auto inst : m->findSymbolByMangledName(hashedName)) + insertGlobalValueSymbol(shared, inst); + } + if (shared->symbols.tryGetValue(hashedName, symbol)) + return symbol; + shared->symbols[hashedName] = nullptr; + return nullptr; + } // The current specialization environment to use. 
IRSpecEnv* env = nullptr; @@ -105,6 +138,26 @@ void registerClonedValue(IRSpecContextBase* context, IRInst* clonedValue, IRInst // an `Add()` call. // context->getEnv()->clonedValues[originalValue] = clonedValue; + + switch (clonedValue->getOp()) + { + case kIROp_LookupWitness: + + // If `originalVal` represents a witness table entry key, add the key + // to witnessTableEntryWorkList. + context->deferredWitnessTableEntryKeys.add( + getMangledName(as<IRLookupWitnessMethod>(clonedValue)->getRequirementKey())); + break; + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + if (context->getShared()->useAutodiff) + { + if (auto key = as<IRStructKey>(clonedValue->getOperand(0))) + context->deferredWitnessTableEntryKeys.add(getMangledName(key)); + } + break; + } } // Information on values to use when registering a cloned value @@ -149,6 +202,28 @@ IRInst* cloneInst(IRSpecContextBase* context, IRBuilder* builder, IRInst* origin return cloneInst(context, builder, originalInst, originalInst); } +bool isAutoDiffDecoration(IRInst* decor) +{ + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePrimalDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: + case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalSubstituteDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDifferentiableDecoration: + return true; + default: + return false; + } +} + /// Clone any decorations from `originalValue` onto `clonedValue` void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* originalValue) { @@ -165,6 +240,8 @@ void 
cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* o SLANG_UNUSED(context); for (auto originalDecoration : originalValue->getDecorations()) { + if (!context->shared->useAutodiff && isAutoDiffDecoration(originalDecoration)) + continue; cloneInst(context, builder, originalDecoration); } @@ -185,6 +262,8 @@ void cloneDecorationsAndChildren( SLANG_UNUSED(context); for (auto originalItem : originalValue->getDecorationsAndChildren()) { + if (!context->shared->useAutodiff && isAutoDiffDecoration(originalItem)) + continue; cloneInst(context, builder, originalItem); } @@ -294,7 +373,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) cloneDecorationsAndChildren(this, clonedValue, originalValue); addHoistableInst(builder, clonedValue); - return clonedValue; } break; @@ -407,15 +485,17 @@ static void cloneExtraDecorationsFromInst( { default: break; - + case kIROp_ForwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: + if (!context->getShared()->useAutodiff) + break; + [[fallthrough]]; case kIROp_HLSLExportDecoration: case kIROp_BindExistentialSlotsDecoration: case kIROp_LayoutDecoration: case kIROp_PublicDecoration: case kIROp_SequentialIDDecoration: - case kIROp_ForwardDerivativeDecoration: - case kIROp_UserDefinedBackwardDerivativeDecoration: - case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: case kIROp_NonCopyableTypeDecoration: case kIROp_DynamicDispatchWitnessDecoration: @@ -602,6 +682,40 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl( return clonedVal; } +bool shouldDeepCloneWitnessTable(IRSpecContextBase* context, IRWitnessTable* table) +{ + for (auto decor : table->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_HLSLExportDecoration: + case kIROp_KeepAliveDecoration: + return true; + } + } + + auto conformanceType = getResolvedInstForDecorations(table->getConformanceType()); + + for (auto decor : 
conformanceType->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ComInterfaceDecoration: + return true; + case kIROp_KnownBuiltinDecoration: + { + auto name = as<IRKnownBuiltinDecoration>(decor)->getName(); + if (name == toSlice("IDifferentiable") || name == toSlice("IDifferentiablePtr")) + return context->getShared()->useAutodiff; + break; + } + default: + break; + } + } + + return false; +} IRWitnessTable* cloneWitnessTableImpl( IRSpecContextBase* context, @@ -612,13 +726,56 @@ IRWitnessTable* cloneWitnessTableImpl( bool registerValue = true) { IRWitnessTable* clonedTable = dstTable; + IRType* clonedBaseType = nullptr; if (!clonedTable) { - auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); + clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType); } - cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); + else + { + clonedBaseType = (IRType*)clonedTable->getConformanceType(); + } + if (registerValue) + registerClonedValue(context, clonedTable, originalValues); + + // Set up an IR builder for inserting into the witness table + IRBuilder builderStorage = *context->builder; + IRBuilder* entryBuilder = &builderStorage; + entryBuilder->setInsertInto(clonedTable); + + // Clone decorations first + for (auto decoration : originalTable->getDecorations()) + { + cloneInst(context, entryBuilder, decoration); + } + cloneExtraDecorations(context, clonedTable, originalValues); + + RefPtr<WitnessTableCloneInfo> witnessInfo = new WitnessTableCloneInfo(); + witnessInfo->clonedTable = clonedTable; + witnessInfo->originalTable = originalTable; + + bool shouldDeepClone = shouldDeepCloneWitnessTable(context, originalTable); + + // Clone only the witness table entries that are actually used + for (auto 
child : originalTable->getDecorationsAndChildren()) + { + if (auto entry = as<IRWitnessTableEntry>(child)) + { + if (!shouldDeepClone) + { + // Skip witness table entries during the first pass, + // and just add them to the deferred work list. + witnessInfo->deferredEntries.add(getMangledName(entry->getRequirementKey()), entry); + continue; + } + } + // Clone any non-entry children as is + cloneInst(context, entryBuilder, child); + } + context->witnessTables.add(witnessInfo); + return clonedTable; } @@ -740,6 +897,9 @@ void cloneGlobalValueWithCodeCommon( } else { + if (oi->getOp() == kIROp_DifferentiableTypeAnnotation && + !context->getShared()->useAutodiff) + continue; cloneInst(context, builder, oi); } } @@ -886,12 +1046,12 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - RefPtr<IRSpecSymbol> sym; - if (!context->getSymbols().tryGetValue(mangledName, sym)) + IRSpecSymbol* sym = context->findSymbols(mangledName.getUnownedSlice()); + if (!sym) { String hashedName = getHashedName(mangledName.getUnownedSlice()); - - if (!context->getSymbols().tryGetValue(hashedName, sym)) + sym = context->findSymbols(hashedName.getUnownedSlice()); + if (!sym) { SLANG_UNEXPECTED("no matching IR symbol"); return nullptr; @@ -967,7 +1127,7 @@ IRFunc* specializeIRForEntryPoint( if (!clonedFunc) { SLANG_UNEXPECTED("expected entry point to be a function"); - return nullptr; + UNREACHABLE_RETURN(nullptr); } if (!clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration)) @@ -1022,7 +1182,8 @@ CapabilitySet getTargetCapabilities(IRSpecContext* context) return context->getShared()->targetReq->getTargetCaps(); } -/// Get the most appropriate ("best") capability requirements for `inVal` based on the `targetCaps`. +/// Get the most appropriate ("best") capability requirements for `inVal` based on the +/// `targetCaps`. 
static CapabilitySet _getBestSpecializationCaps(IRInst* inVal, CapabilitySet const& targetCaps) { IRInst* val = getResolvedInstForDecorations(inVal); @@ -1335,9 +1496,9 @@ IRInst* cloneGlobalValueWithLinkage( // with the same mangled name as `originalVal` and try // to pick the "best" one for our target. - auto mangledName = String(originalLinkage->getMangledName()); - RefPtr<IRSpecSymbol> sym; - if (!context->getSymbols().tryGetValue(mangledName, sym)) + auto mangledName = originalLinkage->getMangledName(); + IRSpecSymbol* sym = context->findSymbols(mangledName); + if (!sym) { if (!originalVal) return nullptr; @@ -1419,6 +1580,11 @@ void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv) { sharedContext->symbols.add(mangledName, sym); } + + if (as<IRImportDecoration>(linkage)) + sharedContext->isImportedSymbol.tryGetValueOrAdd(mangledName, true); + else + sharedContext->isImportedSymbol.set(mangledName, false); } void insertGlobalValueSymbols(IRSharedSpecContext* sharedContext, IRModule* originalModule) @@ -1605,8 +1771,8 @@ void convertAtomicToStorageBuffer( IRSpecContext* context, Dictionary<int, List<IRInst*>>& bindingToInstMapUnsorted) { - // Atomic_uint definitions needs to become a storage buffer to follow GL_EXT_vulkan_glsl_relaxed - // and to allow translation of atomic_uint into SPIRV + // Atomic_uint definitions needs to become a storage buffer to follow + // GL_EXT_vulkan_glsl_relaxed and to allow translation of atomic_uint into SPIRV IRBuilder builder = *context->builder; @@ -1653,8 +1819,8 @@ void convertAtomicToStorageBuffer( instToSwitch->setFullType(storageBuffer); // All references to a atomic_uint need to be an element ref. 
to emulate storage buffer - // usage All function calls must be inlined since storage buffers cannot pass as parameters - // to atomic methods + // usage All function calls must be inlined since storage buffers cannot pass as + // parameters to atomic methods for (auto& i : bindingToInstList.second) { int64_t currOffset = @@ -1722,12 +1888,13 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram, { case kIROp_GLSLAtomicUintType: { - // atomic_uint are supported by GLSL->VK through converting to a different type - // (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by SPIR-V->VK; - // this means that to get SPIR-V to work we must convert the type ourselves to - // an equivlent representation (storage buffer); the added benifit is that then - // HLSL is possible to emit as a target as well since atomic_uint is not an HLSL - // concept, but storageBuffer->RWBuffer is and HLSL concept + // atomic_uint are supported by GLSL->VK through converting to a different + // type (GL_EXT_vulkan_glsl_relaxed). 
atomic_uint are not supported by + // SPIR-V->VK; this means that to get SPIR-V to work we must convert the + // type ourselves to an equivlent representation (storage buffer); the added + // benifit is that then HLSL is possible to emit as a target as well since + // atomic_uint is not an HLSL concept, but storageBuffer->RWBuffer is and + // HLSL concept auto layout = inst->findDecoration<IRLayoutDecoration>()->getLayout(); auto layoutVal = as<IRVarOffsetAttr>(layout->getOperand(1)); assert(layoutVal != nullptr); @@ -1743,6 +1910,110 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram, convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted); } +bool isDiffPairType(IRInst* type) +{ + for (;;) + { + auto type1 = (IRType*)unwrapAttributedType(type); + auto type2 = unwrapArray(type1); + if (type2 == type) + break; + type = type2; + } + return as<IRDifferentialPairTypeBase>(type) != nullptr; +} + +bool doesModuleUseAutodiff(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Call: + if (auto callee = getResolvedInstForDecorations(inst->getOperand(0))) + { + switch (callee->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + return true; + } + } + return false; + case kIROp_DifferentialPairGetDifferentialUserCode: + case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetPrimal: + case kIROp_DifferentialPtrPairGetDifferential: + return true; + case kIROp_StructField: + return isDiffPairType(as<IRStructField>(inst)->getFieldType()); + case kIROp_Param: + return isDiffPairType(inst->getDataType()); + default: + for (auto child : inst->getChildren()) + { + bool isImported = false; + for (auto decor : child->getDecorations()) + { + if (as<IRImportDecoration>(decor)) + { + isImported = true; + break; + } + else if (as<IRAutoPyBindCudaDecoration>(decor)) + { + return true; + } + else if 
(as<IRAutoPyBindExportInfoDecoration>(decor)) + { + return true; + } + } + if (isImported) + continue; + for (auto decor : child->getDecorations()) + { + if (isAutoDiffDecoration(decor)) + return true; + } + if (doesModuleUseAutodiff(child)) + return true; + } + return false; + } +} + +void cloneUsedWitnessTableEntries(IRSpecContext* context) +{ + bool changed = true; + while (changed) + { + changed = false; + for (Index i = 0; i < context->witnessTables.getCount(); i++) + { + auto table = context->witnessTables[i].get(); + ShortList<UnownedStringSlice> entriesToRemove; + for (auto entry : table->deferredEntries) + { + if (context->deferredWitnessTableEntryKeys.contains(entry.first)) + { + IRBuilder builder(table->clonedTable); + builder.setInsertInto(table->clonedTable); + auto deferredKeyCount = context->deferredWitnessTableEntryKeys.getCount(); + cloneInst(context, &builder, entry.second); + entriesToRemove.add(entry.first); + if (deferredKeyCount != context->deferredWitnessTableEntryKeys.getCount()) + changed = true; + } + } + for (auto entry : entriesToRemove) + { + table->deferredEntries.remove(entry); + } + } + } +} + LinkedIR linkIR(CodeGenContext* codeGenContext) { SLANG_PROFILE; @@ -1763,50 +2034,61 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) state->target = target; state->targetReq = targetReq; - + auto& irModules = stateStorage.contextStorage.irModules; auto sharedContext = state->getSharedContext(); initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq); state->irModule = sharedContext->module; - // We need to be able to look up IR definitions for any symbols in // modules that the program depends on (transitively). To // accelerate lookup, we will create a symbol table for looking // up IR definitions by their mangled name. // - - List<IRModule*> irModules; - - // Link the core modules. 
- auto& coreModules = static_cast<Session*>(linkage->getGlobalSession())->coreModules; - for (auto& m : coreModules) - irModules.add(m->getIRModule()); + auto globalSession = static_cast<Session*>(linkage->getGlobalSession()); + List<IRModule*> builtinModules; + for (auto& m : globalSession->coreModules) + builtinModules.add(m->getIRModule()); // Link modules in the program. - program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); }); - - // Add any modules that were loaded as libraries - for (IRModule* irModule : irModules) - { - insertGlobalValueSymbols(sharedContext, irModule); - } + program->enumerateIRModules( + [&](IRModule* module) + { + if (module->getName() == globalSession->glslModuleName) + builtinModules.add(module); + else + irModules.add(module); + }); - // We will also insert the IR global symbols from the IR module + // We will also consider the IR global symbols from the IR module // attached to the `TargetProgram`, since this module is // responsible for associating layout information to those // global symbols via decorations. // auto irModuleForLayout = targetProgram->getExistingIRModuleForLayout(); - insertGlobalValueSymbols(sharedContext, irModuleForLayout); + if (irModuleForLayout) + irModules.add(irModuleForLayout); + + Index userModuleCount = irModules.getCount(); + irModules.addRange(builtinModules); + ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount); + + // Check if any user module uses auto-diff, if so we will need to link + // additional witnesses and decorations. 
+ for (IRModule* irModule : userModules) + { + if (sharedContext->useAutodiff) + break; + sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst()); + } auto context = state->getContext(); // Combine all of the contents of IRGlobalHashedStringLiterals { StringSlicePool pool(StringSlicePool::Style::Empty); - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { findGlobalHashedStringLiterals(irModule, pool); } @@ -1875,7 +2157,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // instructions in all the input modules. // - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { for (auto inst : irModule->getGlobalInsts()) { @@ -1896,7 +2178,9 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // We need to copy over exported symbols, // and any global parameters if preserve-params option is set. if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) || - as<IRDifferentiableTypeAnnotation>(inst)) + sharedContext->useAutodiff && + (as<IRDifferentiableTypeAnnotation>(inst) || + inst->findDecorationImpl(kIROp_AutoDiffBuiltinDecoration) != nullptr)) { auto cloned = cloneValue(context, inst); if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration)) @@ -1907,6 +2191,12 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) } } + // In previous steps, we have skipped cloning the witness table entries, and + // registered any used witness table entry keys to context->deferredWitnessTableEntryKeys + // for on-demand cloning. Now we will use the deferred keys to clone the witness table + // entries that are referenced. + cloneUsedWitnessTableEntries(context); + // It is possible that metadata has been attached to the input modules // themselves, which should be copied over to the output module. // @@ -1917,7 +2207,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // `[assumedWaveSize(...)]` decoration might require that all specified // values match exactly). 
// - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { for (auto decoration : irModule->getModuleInst()->getDecorations()) { |
