diff options
Diffstat (limited to 'source/slang/slang-ir-link.cpp')
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 390 |
1 files changed, 340 insertions, 50 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 53bbbda9e..125ab71d0 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -52,15 +52,28 @@ struct IRSharedSpecContext // A map from mangled symbol names to zero or // more global IR values that have that name, // in the *original* module. - typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary; + typedef Dictionary<ImmutableHashedString, RefPtr<IRSpecSymbol>> SymbolDictionary; SymbolDictionary symbols; + Dictionary<ImmutableHashedString, bool> isImportedSymbol; + + bool useAutodiff = false; + IRBuilder builderStorage; // The "global" specialization environment. IRSpecEnv globalEnv; }; +void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv); + +struct WitnessTableCloneInfo : RefObject +{ + IRWitnessTable* clonedTable; + IRWitnessTable* originalTable; + Dictionary<UnownedStringSlice, IRWitnessTableEntry*> deferredEntries; +}; + struct IRSpecContextBase { IRSharedSpecContext* shared; @@ -69,7 +82,27 @@ struct IRSpecContextBase IRModule* getModule() { return getShared()->module; } - IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } + List<IRModule*> irModules; + + HashSet<UnownedStringSlice> deferredWitnessTableEntryKeys; + List<RefPtr<WitnessTableCloneInfo>> witnessTables; + + IRSpecSymbol* findSymbols(UnownedStringSlice mangledName) + { + ImmutableHashedString hashedName(mangledName); + RefPtr<IRSpecSymbol> symbol; + if (shared->symbols.tryGetValue(hashedName, symbol)) + return symbol; + for (auto m : irModules) + { + for (auto inst : m->findSymbolByMangledName(hashedName)) + insertGlobalValueSymbol(shared, inst); + } + if (shared->symbols.tryGetValue(hashedName, symbol)) + return symbol; + shared->symbols[hashedName] = nullptr; + return nullptr; + } // The current specialization environment to use. IRSpecEnv* env = nullptr; @@ -105,6 +138,26 @@ void registerClonedValue(IRSpecContextBase* context, IRInst* clonedValue, IRInst // an `Add()` call. // context->getEnv()->clonedValues[originalValue] = clonedValue; + + switch (clonedValue->getOp()) + { + case kIROp_LookupWitness: + + // If `originalVal` represents a witness table entry key, add the key + // to witnessTableEntryWorkList. + context->deferredWitnessTableEntryKeys.add( + getMangledName(as<IRLookupWitnessMethod>(clonedValue)->getRequirementKey())); + break; + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + if (context->getShared()->useAutodiff) + { + if (auto key = as<IRStructKey>(clonedValue->getOperand(0))) + context->deferredWitnessTableEntryKeys.add(getMangledName(key)); + } + break; + } } // Information on values to use when registering a cloned value @@ -149,6 +202,28 @@ IRInst* cloneInst(IRSpecContextBase* context, IRBuilder* builder, IRInst* origin return cloneInst(context, builder, originalInst, originalInst); } +bool isAutoDiffDecoration(IRInst* decor) +{ + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePrimalDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: + case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalSubstituteDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDifferentiableDecoration: + return true; + default: + return false; + } +} + /// Clone any decorations from `originalValue` onto `clonedValue` void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* originalValue) { @@ -165,6 +240,8 @@ void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* o SLANG_UNUSED(context); for (auto originalDecoration : originalValue->getDecorations()) { + if (!context->shared->useAutodiff && isAutoDiffDecoration(originalDecoration)) + continue; cloneInst(context, builder, originalDecoration); } @@ -185,6 +262,8 @@ void cloneDecorationsAndChildren( SLANG_UNUSED(context); for (auto originalItem : originalValue->getDecorationsAndChildren()) { + if (!context->shared->useAutodiff && isAutoDiffDecoration(originalItem)) + continue; cloneInst(context, builder, originalItem); } @@ -294,7 +373,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) cloneDecorationsAndChildren(this, clonedValue, originalValue); addHoistableInst(builder, clonedValue); - return clonedValue; } break; @@ -407,15 +485,17 @@ static void cloneExtraDecorationsFromInst( { default: break; - + case kIROp_ForwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: + if (!context->getShared()->useAutodiff) + break; + [[fallthrough]]; case kIROp_HLSLExportDecoration: case kIROp_BindExistentialSlotsDecoration: case kIROp_LayoutDecoration: case kIROp_PublicDecoration: case kIROp_SequentialIDDecoration: - case kIROp_ForwardDerivativeDecoration: - case kIROp_UserDefinedBackwardDerivativeDecoration: - case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: case kIROp_NonCopyableTypeDecoration: case kIROp_DynamicDispatchWitnessDecoration: @@ -602,6 +682,40 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl( return clonedVal; } +bool shouldDeepCloneWitnessTable(IRSpecContextBase* context, IRWitnessTable* table) +{ + for (auto decor : table->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_HLSLExportDecoration: + case kIROp_KeepAliveDecoration: + return true; + } + } + + auto conformanceType = getResolvedInstForDecorations(table->getConformanceType()); + + for (auto decor : conformanceType->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ComInterfaceDecoration: + return true; + case kIROp_KnownBuiltinDecoration: + { + auto name = as<IRKnownBuiltinDecoration>(decor)->getName(); + if (name == toSlice("IDifferentiable") || name == toSlice("IDifferentiablePtr")) + return context->getShared()->useAutodiff; + break; + } + default: + break; + } + } + + return false; +} IRWitnessTable* cloneWitnessTableImpl( IRSpecContextBase* context, @@ -612,13 +726,56 @@ IRWitnessTable* cloneWitnessTableImpl( bool registerValue = true) { IRWitnessTable* clonedTable = dstTable; + IRType* clonedBaseType = nullptr; if (!clonedTable) { - auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); + clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType); } - cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); + else + { + clonedBaseType = (IRType*)clonedTable->getConformanceType(); + } + if (registerValue) + registerClonedValue(context, clonedTable, originalValues); + + // Set up an IR builder for inserting into the witness table + IRBuilder builderStorage = *context->builder; + IRBuilder* entryBuilder = &builderStorage; + entryBuilder->setInsertInto(clonedTable); + + // Clone decorations first + for (auto decoration : originalTable->getDecorations()) + { + cloneInst(context, entryBuilder, decoration); + } + cloneExtraDecorations(context, clonedTable, originalValues); + + RefPtr<WitnessTableCloneInfo> witnessInfo = new WitnessTableCloneInfo(); + witnessInfo->clonedTable = clonedTable; + witnessInfo->originalTable = originalTable; + + bool shouldDeepClone = shouldDeepCloneWitnessTable(context, originalTable); + + // Clone only the witness table entries that are actually used + for (auto child : originalTable->getDecorationsAndChildren()) + { + if (auto entry = as<IRWitnessTableEntry>(child)) + { + if (!shouldDeepClone) + { + // Skip witness table entries during the first pass, + // and just add them to the deferred work list. + witnessInfo->deferredEntries.add(getMangledName(entry->getRequirementKey()), entry); + continue; + } + } + // Clone any non-entry children as is + cloneInst(context, entryBuilder, child); + } + context->witnessTables.add(witnessInfo); + return clonedTable; } @@ -740,6 +897,9 @@ void cloneGlobalValueWithCodeCommon( } else { + if (oi->getOp() == kIROp_DifferentiableTypeAnnotation && + !context->getShared()->useAutodiff) + continue; cloneInst(context, builder, oi); } } @@ -886,12 +1046,12 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - RefPtr<IRSpecSymbol> sym; - if (!context->getSymbols().tryGetValue(mangledName, sym)) + IRSpecSymbol* sym = context->findSymbols(mangledName.getUnownedSlice()); + if (!sym) { String hashedName = getHashedName(mangledName.getUnownedSlice()); - - if (!context->getSymbols().tryGetValue(hashedName, sym)) + sym = context->findSymbols(hashedName.getUnownedSlice()); + if (!sym) { SLANG_UNEXPECTED("no matching IR symbol"); return nullptr; @@ -967,7 +1127,7 @@ IRFunc* specializeIRForEntryPoint( if (!clonedFunc) { SLANG_UNEXPECTED("expected entry point to be a function"); - return nullptr; + UNREACHABLE_RETURN(nullptr); } if (!clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration)) @@ -1022,7 +1182,8 @@ CapabilitySet getTargetCapabilities(IRSpecContext* context) return context->getShared()->targetReq->getTargetCaps(); } -/// Get the most appropriate ("best") capability requirements for `inVal` based on the `targetCaps`. +/// Get the most appropriate ("best") capability requirements for `inVal` based on the +/// `targetCaps`. static CapabilitySet _getBestSpecializationCaps(IRInst* inVal, CapabilitySet const& targetCaps) { IRInst* val = getResolvedInstForDecorations(inVal); @@ -1335,9 +1496,9 @@ IRInst* cloneGlobalValueWithLinkage( // with the same mangled name as `originalVal` and try // to pick the "best" one for our target. - auto mangledName = String(originalLinkage->getMangledName()); - RefPtr<IRSpecSymbol> sym; - if (!context->getSymbols().tryGetValue(mangledName, sym)) + auto mangledName = originalLinkage->getMangledName(); + IRSpecSymbol* sym = context->findSymbols(mangledName); + if (!sym) { if (!originalVal) return nullptr; @@ -1419,6 +1580,11 @@ void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv) { sharedContext->symbols.add(mangledName, sym); } + + if (as<IRImportDecoration>(linkage)) + sharedContext->isImportedSymbol.tryGetValueOrAdd(mangledName, true); + else + sharedContext->isImportedSymbol.set(mangledName, false); } void insertGlobalValueSymbols(IRSharedSpecContext* sharedContext, IRModule* originalModule) @@ -1605,8 +1771,8 @@ void convertAtomicToStorageBuffer( IRSpecContext* context, Dictionary<int, List<IRInst*>>& bindingToInstMapUnsorted) { - // Atomic_uint definitions needs to become a storage buffer to follow GL_EXT_vulkan_glsl_relaxed - // and to allow translation of atomic_uint into SPIRV + // Atomic_uint definitions needs to become a storage buffer to follow + // GL_EXT_vulkan_glsl_relaxed and to allow translation of atomic_uint into SPIRV IRBuilder builder = *context->builder; @@ -1653,8 +1819,8 @@ void convertAtomicToStorageBuffer( instToSwitch->setFullType(storageBuffer); // All references to a atomic_uint need to be an element ref. to emulate storage buffer - // usage All function calls must be inlined since storage buffers cannot pass as parameters - // to atomic methods + // usage All function calls must be inlined since storage buffers cannot pass as + // parameters to atomic methods for (auto& i : bindingToInstList.second) { int64_t currOffset = @@ -1722,12 +1888,13 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram, { case kIROp_GLSLAtomicUintType: { - // atomic_uint are supported by GLSL->VK through converting to a different type - // (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by SPIR-V->VK; - // this means that to get SPIR-V to work we must convert the type ourselves to - // an equivlent representation (storage buffer); the added benifit is that then - // HLSL is possible to emit as a target as well since atomic_uint is not an HLSL - // concept, but storageBuffer->RWBuffer is and HLSL concept + // atomic_uint are supported by GLSL->VK through converting to a different + // type (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by + // SPIR-V->VK; this means that to get SPIR-V to work we must convert the + // type ourselves to an equivlent representation (storage buffer); the added + // benifit is that then HLSL is possible to emit as a target as well since + // atomic_uint is not an HLSL concept, but storageBuffer->RWBuffer is and + // HLSL concept auto layout = inst->findDecoration<IRLayoutDecoration>()->getLayout(); auto layoutVal = as<IRVarOffsetAttr>(layout->getOperand(1)); assert(layoutVal != nullptr); @@ -1743,6 +1910,110 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram, convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted); } +bool isDiffPairType(IRInst* type) +{ + for (;;) + { + auto type1 = (IRType*)unwrapAttributedType(type); + auto type2 = unwrapArray(type1); + if (type2 == type) + break; + type = type2; + } + return as<IRDifferentialPairTypeBase>(type) != nullptr; +} + +bool doesModuleUseAutodiff(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Call: + if (auto callee = getResolvedInstForDecorations(inst->getOperand(0))) + { + switch (callee->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + return true; + } + } + return false; + case kIROp_DifferentialPairGetDifferentialUserCode: + case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetPrimal: + case kIROp_DifferentialPtrPairGetDifferential: + return true; + case kIROp_StructField: + return isDiffPairType(as<IRStructField>(inst)->getFieldType()); + case kIROp_Param: + return isDiffPairType(inst->getDataType()); + default: + for (auto child : inst->getChildren()) + { + bool isImported = false; + for (auto decor : child->getDecorations()) + { + if (as<IRImportDecoration>(decor)) + { + isImported = true; + break; + } + else if (as<IRAutoPyBindCudaDecoration>(decor)) + { + return true; + } + else if (as<IRAutoPyBindExportInfoDecoration>(decor)) + { + return true; + } + } + if (isImported) + continue; + for (auto decor : child->getDecorations()) + { + if (isAutoDiffDecoration(decor)) + return true; + } + if (doesModuleUseAutodiff(child)) + return true; + } + return false; + } +} + +void cloneUsedWitnessTableEntries(IRSpecContext* context) +{ + bool changed = true; + while (changed) + { + changed = false; + for (Index i = 0; i < context->witnessTables.getCount(); i++) + { + auto table = context->witnessTables[i].get(); + ShortList<UnownedStringSlice> entriesToRemove; + for (auto entry : table->deferredEntries) + { + if (context->deferredWitnessTableEntryKeys.contains(entry.first)) + { + IRBuilder builder(table->clonedTable); + builder.setInsertInto(table->clonedTable); + auto deferredKeyCount = context->deferredWitnessTableEntryKeys.getCount(); + cloneInst(context, &builder, entry.second); + entriesToRemove.add(entry.first); + if (deferredKeyCount != context->deferredWitnessTableEntryKeys.getCount()) + changed = true; + } + } + for (auto entry : entriesToRemove) + { + table->deferredEntries.remove(entry); + } + } + } +} + LinkedIR linkIR(CodeGenContext* codeGenContext) { SLANG_PROFILE; @@ -1763,50 +2034,61 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) state->target = target; state->targetReq = targetReq; - + auto& irModules = stateStorage.contextStorage.irModules; auto sharedContext = state->getSharedContext(); initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq); state->irModule = sharedContext->module; - // We need to be able to look up IR definitions for any symbols in // modules that the program depends on (transitively). To // accelerate lookup, we will create a symbol table for looking // up IR definitions by their mangled name. // - - List<IRModule*> irModules; - - // Link the core modules. - auto& coreModules = static_cast<Session*>(linkage->getGlobalSession())->coreModules; - for (auto& m : coreModules) - irModules.add(m->getIRModule()); + auto globalSession = static_cast<Session*>(linkage->getGlobalSession()); + List<IRModule*> builtinModules; + for (auto& m : globalSession->coreModules) + builtinModules.add(m->getIRModule()); // Link modules in the program. - program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); }); - - // Add any modules that were loaded as libraries - for (IRModule* irModule : irModules) - { - insertGlobalValueSymbols(sharedContext, irModule); - } + program->enumerateIRModules( + [&](IRModule* module) + { + if (module->getName() == globalSession->glslModuleName) + builtinModules.add(module); + else + irModules.add(module); + }); - // We will also insert the IR global symbols from the IR module + // We will also consider the IR global symbols from the IR module // attached to the `TargetProgram`, since this module is // responsible for associating layout information to those // global symbols via decorations. // auto irModuleForLayout = targetProgram->getExistingIRModuleForLayout(); - insertGlobalValueSymbols(sharedContext, irModuleForLayout); + if (irModuleForLayout) + irModules.add(irModuleForLayout); + + Index userModuleCount = irModules.getCount(); + irModules.addRange(builtinModules); + ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount); + + // Check if any user module uses auto-diff, if so we will need to link + // additional witnesses and decorations. + for (IRModule* irModule : userModules) + { + if (sharedContext->useAutodiff) + break; + sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst()); + } auto context = state->getContext(); // Combine all of the contents of IRGlobalHashedStringLiterals { StringSlicePool pool(StringSlicePool::Style::Empty); - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { findGlobalHashedStringLiterals(irModule, pool); } @@ -1875,7 +2157,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // instructions in all the input modules. // - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { for (auto inst : irModule->getGlobalInsts()) { @@ -1896,7 +2178,9 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // We need to copy over exported symbols, // and any global parameters if preserve-params option is set. if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) || - as<IRDifferentiableTypeAnnotation>(inst)) + sharedContext->useAutodiff && + (as<IRDifferentiableTypeAnnotation>(inst) || + inst->findDecorationImpl(kIROp_AutoDiffBuiltinDecoration) != nullptr)) { auto cloned = cloneValue(context, inst); if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration)) @@ -1907,6 +2191,12 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) } } + // In previous steps, we have skipped cloning the witness table entries, and + // registered any used witness table entry keys to context->deferredWitnessTableEntryKeys + // for on-demand cloning. Now we will use the deferred keys to clone the witness table + // entries that are referenced. + cloneUsedWitnessTableEntries(context); + // It is possible that metadata has been attached to the input modules // themselves, which should be copied over to the output module. // @@ -1917,7 +2207,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // `[assumedWaveSize(...)]` decoration might require that all specified // values match exactly). // - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { for (auto decoration : irModule->getModuleInst()->getDecorations()) { |
