diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 10 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 33 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 71 | ||||
| -rw-r--r-- | source/slang/slang-compiler.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 390 | ||||
| -rw-r--r-- | source/slang/slang-ir-strip.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-ir-strip.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 43 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ir.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 29 |
21 files changed, 598 insertions, 108 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 267f7b2d4..da1b47e13 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -603,6 +603,7 @@ void static_assert(constexpr bool condition, NativeString errorMessage); /// /// __magic_type(DifferentiableType) +[KnownBuiltin("IDifferentiable")] interface IDifferentiable { // Note: the compiler implementation requires the `Differential` associated type to be defined @@ -645,6 +646,7 @@ interface IDifferentiable /// @remarks Support for this interface is still experimental and subject to change. /// __magic_type(DifferentiablePtrType) +[KnownBuiltin("IDifferentiablePtr")] interface IDifferentiablePtrType { __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) ) @@ -2598,20 +2600,20 @@ for(auto fixity : kIncDecFixities) $(fixity.qual) __generic<T : __BuiltinArithmeticType> [__unsafeForceInlineEarly] -T operator$(op.name)(in out T value) -{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); } +T operator$(op.name)( in out T value) +{ $(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); } $(fixity.qual) __generic<T : __BuiltinArithmeticType, let N : int> [__unsafeForceInlineEarly] vector<T,N> operator$(op.name)(in out vector<T,N> value) -{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); } +{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); } $(fixity.qual) __generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int> [__unsafeForceInlineEarly] matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value) -{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); } +{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); } $(fixity.qual) __generic<T, let addrSpace : uint64_t> diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 790dfaa79..d32584db0 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -176,7 +176,7 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute; // for internal use. // [__AutoDiffBuiltin] -export struct NullDifferential : IDifferentiable +struct NullDifferential : IDifferentiable { // for now, we'll use at least one field to make sure the type is non-empty uint dummy; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 173b97fd3..78eb48a33 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20309,7 +20309,7 @@ void ReorderThread( HitObject HitOrMiss ) /// /// There doesn't appear to be an equivalent for debugBreak for HLSL - +[require(glsl)] __specialized_for_target(glsl) [[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]] void __glslDebugBreak(); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5b5e05b73..0c42817c8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10452,6 +10452,12 @@ void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDec { return; } + + if (getText(moduleDecl->getName()) == "glsl") + { + getShared()->glslModuleDecl = moduleDecl; + } + importedModulesList.add(moduleDecl); importedModulesSet.add(moduleDecl); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 6438a91e3..edb199299 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -172,15 +172,17 @@ struct BasicTypeKeyPair struct OperatorOverloadCacheKey { - intptr_t operatorName; + int32_t operatorName; + bool isGLSLMode; BasicTypeKey args[2]; bool operator==(OperatorOverloadCacheKey key) const { - return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1]; + return operatorName == key.operatorName && args[0] == key.args[0] && + args[1] == key.args[1] && isGLSLMode == key.isGLSLMode; } HashCode getHashCode() const { - return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw()); + return combineHash(operatorName, args[0].getRaw(), args[1].getRaw(), isGLSLMode ? 1 : 0); } bool fromOperatorExpr(OperatorExpr* opExpr) { @@ -299,10 +301,28 @@ struct OverloadCandidate SubstitutionSet subst; }; -struct TypeCheckingCache +struct ResolvedOperatorOverload { - Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache; + // The resolved decl. + Decl* decl; + + // The cached overload candidate in the current TypeCheckingCache. + // Note that a `OverloadCandidate` object is not migratable over different + // Linkages (compile sessions), so we will need to use `cacheVersion` to track + // if this `candidate` is valid for the current session. If not, we will + // recreate it from `decl`. + OverloadCandidate candidate; + // The version of the TypeCheckingCache for which the cached candidate is valid. + int cacheVersion; +}; + +struct TypeCheckingCache : public RefObject +{ + Dictionary<OperatorOverloadCacheKey, ResolvedOperatorOverload> resolvedOperatorOverloadCache; Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache; + + // The version used to invalidate the cached declRefs in ResolvedOperatorOverload entries. + int version = 0; }; enum class CoercionSite @@ -635,6 +655,9 @@ struct SharedSemanticsContext : public RefObject DiagnosticSink* m_sink = nullptr; + // Whether the current module has imported the GLSL module. + ModuleDecl* glslModuleDecl = nullptr; + /// (optional) modules that comes from previously processed translation units in the /// front-end request that are made visible to the module being checked. This allows /// `import` to use them instead of trying to find the files in file system. diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index b75f95f9a..f548fb819 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2509,27 +2509,6 @@ String SemanticsVisitor::getCallSignatureString(OverloadResolveContext& context) Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) { OverloadResolveContext context; - // check if this is a core module operator call, if so we want to use cached results - // to speed up compilation - bool shouldAddToCache = false; - OperatorOverloadCacheKey key; - TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache(); - if (auto opExpr = as<OperatorExpr>(expr)) - { - if (key.fromOperatorExpr(opExpr)) - { - OverloadCandidate candidate; - if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate)) - { - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; - } - else - { - shouldAddToCache = true; - } - } - } // Look at the base expression for the call, and figure out how to invoke it. auto funcExpr = expr->functionExpr; @@ -2569,6 +2548,43 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) context.loc = expr->loc; context.sourceScope = m_outerScope; context.baseExpr = GetBaseExpr(funcExpr); + + // check if this is a core module operator call, if so we want to use cached results + // to speed up compilation + bool shouldAddToCache = false; + OperatorOverloadCacheKey key; + TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache(); + if (auto opExpr = as<OperatorExpr>(expr)) + { + if (key.fromOperatorExpr(opExpr)) + { + key.isGLSLMode = getShared()->glslModuleDecl != nullptr; + ResolvedOperatorOverload candidate; + if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate)) + { + // We should only use the cached candidate if it is persistent direct declref + // created from GlobalSession's ASTBuilder, or it is created in the current Linkage. + if (candidate.cacheVersion == typeCheckingCache->version || + as<DirectDeclRef>(candidate.candidate.item.declRef.declRefBase)) + { + context.bestCandidateStorage = candidate.candidate; + context.bestCandidate = &context.bestCandidateStorage; + } + else + { + LookupResultItem overloadCandidate = {}; + overloadCandidate.declRef = getOuterGenericOrSelf(candidate.decl); + AddDeclRefOverloadCandidates(overloadCandidate, context, 0); + shouldAddToCache = true; + } + } + else + { + shouldAddToCache = true; + } + } + } + // We run a special case here where an `InvokeExpr` // with a single argument where the base/func expression names // a type should always be treated as an explicit type coercion @@ -2731,7 +2747,18 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) // We will report errors for this one candidate, then, to give // the user the most help we can. if (shouldAddToCache) - typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; + { + if (isFromCoreModule(context.bestCandidate->item.declRef.getDecl()) || + getShared()->glslModuleDecl == + getModuleDecl(context.bestCandidate->item.declRef.getDecl())) + { + ResolvedOperatorOverload overloadResult; + overloadResult.candidate = *context.bestCandidate; + overloadResult.decl = context.bestCandidate->item.declRef.getDecl(); + overloadResult.cacheVersion = typeCheckingCache->version; + typeCheckingCache->resolvedOperatorOverloadCache[key] = overloadResult; + } + } // Now that we have resolved the overload candidate, we need to undo an `openExistential` // operation that was applied to `out` arguments. diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 38725fff3..3e00e4b04 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2209,7 +2209,7 @@ public: TypeCheckingCache* getTypeCheckingCache(); void destroyTypeCheckingCache(); - TypeCheckingCache* m_typeCheckingCache = nullptr; + RefPtr<RefObject> m_typeCheckingCache = nullptr; // Modules that have been dynamically loaded via `import` // @@ -3589,6 +3589,9 @@ public: int m_typeDictionarySize = 0; + RefPtr<RefObject> m_typeCheckingCache; + TypeCheckingCache* getTypeCheckingCache(); + private: struct BuiltinModuleInfo { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index cf653d001..03f52a701 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -49,6 +49,7 @@ DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'") DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'") DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'") DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'") +DIAGNOSTIC(-1, Note, seeCallOfFunc, "see call to '$0'") DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition") DIAGNOSTIC( -1, @@ -2309,7 +2310,7 @@ DIAGNOSTIC( 41402, Error, staticAssertionConditionNotConstant, - "condition for static assertion cannot be evaluated at the compile-time.") + "condition for static assertion cannot be evaluated at compile time.") DIAGNOSTIC( 41402, diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index e20a4a90f..847c5b55c 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -475,6 +475,31 @@ void calcRequiredLoweringPassSet( } } +void diagnoseCallStack(IRInst* inst, DiagnosticSink* sink) +{ + static const int maxDepth = 5; + for (int i = 0; i < maxDepth; i++) + { + auto func = getParentFunc(inst); + if (!func) + return; + bool shouldContinue = false; + for (auto use = func->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto call = as<IRCall>(user)) + { + sink->diagnose(call, Diagnostics::seeCallOfFunc, func); + inst = call; + shouldContinue = true; + break; + } + } + if (!shouldContinue) + return; + } +} + bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink) { switch (inst->getOp()) @@ -498,6 +523,7 @@ bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink) { sink->diagnose(inst, Diagnostics::staticAssertionFailureWithoutMessage); } + diagnoseCallStack(inst, sink); } } else diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6fc7d56ad..d8500a694 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -2012,6 +2012,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeArrayFromElement: case kIROp_MakeTuple: case kIROp_MakeValuePack: + case kIROp_BuiltinCast: return transcribeConstruct(builder, origInst); case kIROp_MakeStruct: return transcribeMakeStruct(builder, origInst); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 9075002e0..40dcb1b51 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1026,12 +1026,9 @@ IRInst* AutoDiffSharedContext::findDifferentiableInterface() { for (auto globalInst : module->getGlobalInsts()) { - // TODO: This seems like a particularly dangerous way to look for an interface. - // See if we can lower IDifferentiable to a separate IR inst. - // if (auto intf = as<IRInterfaceType>(globalInst)) { - if (auto decor = intf->findDecoration<IRNameHintDecoration>()) + if (auto decor = intf->findDecoration<IRKnownBuiltinDecoration>()) { if (decor->getName() == toSlice("IDifferentiable")) { @@ -1261,7 +1258,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) } addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness()); - +#if 0 // TODO: Is this really needed? if (!as<IRInterfaceType>(item->getBaseType()) && !as<IRAssociatedType>(item->getBaseType())) @@ -1314,6 +1311,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) addTypeToDictionary((IRType*)diffType, diffWitness); } } +#endif } } } diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 53bbbda9e..125ab71d0 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -52,15 +52,28 @@ struct IRSharedSpecContext // A map from mangled symbol names to zero or // more global IR values that have that name, // in the *original* module. - typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary; + typedef Dictionary<ImmutableHashedString, RefPtr<IRSpecSymbol>> SymbolDictionary; SymbolDictionary symbols; + Dictionary<ImmutableHashedString, bool> isImportedSymbol; + + bool useAutodiff = false; + IRBuilder builderStorage; // The "global" specialization environment. IRSpecEnv globalEnv; }; +void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv); + +struct WitnessTableCloneInfo : RefObject +{ + IRWitnessTable* clonedTable; + IRWitnessTable* originalTable; + Dictionary<UnownedStringSlice, IRWitnessTableEntry*> deferredEntries; +}; + struct IRSpecContextBase { IRSharedSpecContext* shared; @@ -69,7 +82,27 @@ struct IRSpecContextBase IRModule* getModule() { return getShared()->module; } - IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } + List<IRModule*> irModules; + + HashSet<UnownedStringSlice> deferredWitnessTableEntryKeys; + List<RefPtr<WitnessTableCloneInfo>> witnessTables; + + IRSpecSymbol* findSymbols(UnownedStringSlice mangledName) + { + ImmutableHashedString hashedName(mangledName); + RefPtr<IRSpecSymbol> symbol; + if (shared->symbols.tryGetValue(hashedName, symbol)) + return symbol; + for (auto m : irModules) + { + for (auto inst : m->findSymbolByMangledName(hashedName)) + insertGlobalValueSymbol(shared, inst); + } + if (shared->symbols.tryGetValue(hashedName, symbol)) + return symbol; + shared->symbols[hashedName] = nullptr; + return nullptr; + } // The current specialization environment to use. IRSpecEnv* env = nullptr; @@ -105,6 +138,26 @@ void registerClonedValue(IRSpecContextBase* context, IRInst* clonedValue, IRInst // an `Add()` call. // context->getEnv()->clonedValues[originalValue] = clonedValue; + + switch (clonedValue->getOp()) + { + case kIROp_LookupWitness: + + // If `originalVal` represents a witness table entry key, add the key + // to witnessTableEntryWorkList. + context->deferredWitnessTableEntryKeys.add( + getMangledName(as<IRLookupWitnessMethod>(clonedValue)->getRequirementKey())); + break; + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + if (context->getShared()->useAutodiff) + { + if (auto key = as<IRStructKey>(clonedValue->getOperand(0))) + context->deferredWitnessTableEntryKeys.add(getMangledName(key)); + } + break; + } } // Information on values to use when registering a cloned value @@ -149,6 +202,28 @@ IRInst* cloneInst(IRSpecContextBase* context, IRBuilder* builder, IRInst* origin return cloneInst(context, builder, originalInst, originalInst); } +bool isAutoDiffDecoration(IRInst* decor) +{ + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePrimalDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: + case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalSubstituteDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDifferentiableDecoration: + return true; + default: + return false; + } +} + /// Clone any decorations from `originalValue` onto `clonedValue` void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* originalValue) { @@ -165,6 +240,8 @@ void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* o SLANG_UNUSED(context); for (auto originalDecoration : originalValue->getDecorations()) { + if (!context->shared->useAutodiff && isAutoDiffDecoration(originalDecoration)) + continue; cloneInst(context, builder, originalDecoration); } @@ -185,6 +262,8 @@ void cloneDecorationsAndChildren( SLANG_UNUSED(context); for (auto originalItem : originalValue->getDecorationsAndChildren()) { + if (!context->shared->useAutodiff && isAutoDiffDecoration(originalItem)) + continue; cloneInst(context, builder, originalItem); } @@ -294,7 +373,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) cloneDecorationsAndChildren(this, clonedValue, originalValue); addHoistableInst(builder, clonedValue); - return clonedValue; } break; @@ -407,15 +485,17 @@ static void cloneExtraDecorationsFromInst( { default: break; - + case kIROp_ForwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: + if (!context->getShared()->useAutodiff) + break; + [[fallthrough]]; case kIROp_HLSLExportDecoration: case kIROp_BindExistentialSlotsDecoration: case kIROp_LayoutDecoration: case kIROp_PublicDecoration: case kIROp_SequentialIDDecoration: - case kIROp_ForwardDerivativeDecoration: - case kIROp_UserDefinedBackwardDerivativeDecoration: - case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: case kIROp_NonCopyableTypeDecoration: case kIROp_DynamicDispatchWitnessDecoration: @@ -602,6 +682,40 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl( return clonedVal; } +bool shouldDeepCloneWitnessTable(IRSpecContextBase* context, IRWitnessTable* table) +{ + for (auto decor : table->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_HLSLExportDecoration: + case kIROp_KeepAliveDecoration: + return true; + } + } + + auto conformanceType = getResolvedInstForDecorations(table->getConformanceType()); + + for (auto decor : conformanceType->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ComInterfaceDecoration: + return true; + case kIROp_KnownBuiltinDecoration: + { + auto name = as<IRKnownBuiltinDecoration>(decor)->getName(); + if (name == toSlice("IDifferentiable") || name == toSlice("IDifferentiablePtr")) + return context->getShared()->useAutodiff; + break; + } + default: + break; + } + } + + return false; +} IRWitnessTable* cloneWitnessTableImpl( IRSpecContextBase* context, @@ -612,13 +726,56 @@ IRWitnessTable* cloneWitnessTableImpl( bool registerValue = true) { IRWitnessTable* clonedTable = dstTable; + IRType* clonedBaseType = nullptr; if (!clonedTable) { - auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); + clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType); } - cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); + else + { + clonedBaseType = (IRType*)clonedTable->getConformanceType(); + } + if (registerValue) + registerClonedValue(context, clonedTable, originalValues); + + // Set up an IR builder for inserting into the witness table + IRBuilder builderStorage = *context->builder; + IRBuilder* entryBuilder = &builderStorage; + entryBuilder->setInsertInto(clonedTable); + + // Clone decorations first + for (auto decoration : originalTable->getDecorations()) + { + cloneInst(context, entryBuilder, decoration); + } + cloneExtraDecorations(context, clonedTable, originalValues); + + RefPtr<WitnessTableCloneInfo> witnessInfo = new WitnessTableCloneInfo(); + witnessInfo->clonedTable = clonedTable; + witnessInfo->originalTable = originalTable; + + bool shouldDeepClone = shouldDeepCloneWitnessTable(context, originalTable); + + // Clone only the witness table entries that are actually used + for (auto child : originalTable->getDecorationsAndChildren()) + { + if (auto entry = as<IRWitnessTableEntry>(child)) + { + if (!shouldDeepClone) + { + // Skip witness table entries during the first pass, + // and just add them to the deferred work list. + witnessInfo->deferredEntries.add(getMangledName(entry->getRequirementKey()), entry); + continue; + } + } + // Clone any non-entry children as is + cloneInst(context, entryBuilder, child); + } + context->witnessTables.add(witnessInfo); + return clonedTable; } @@ -740,6 +897,9 @@ void cloneGlobalValueWithCodeCommon( } else { + if (oi->getOp() == kIROp_DifferentiableTypeAnnotation && + !context->getShared()->useAutodiff) + continue; cloneInst(context, builder, oi); } } @@ -886,12 +1046,12 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - RefPtr<IRSpecSymbol> sym; - if (!context->getSymbols().tryGetValue(mangledName, sym)) + IRSpecSymbol* sym = context->findSymbols(mangledName.getUnownedSlice()); + if (!sym) { String hashedName = getHashedName(mangledName.getUnownedSlice()); - - if (!context->getSymbols().tryGetValue(hashedName, sym)) + sym = context->findSymbols(hashedName.getUnownedSlice()); + if (!sym) { SLANG_UNEXPECTED("no matching IR symbol"); return nullptr; @@ -967,7 +1127,7 @@ IRFunc* specializeIRForEntryPoint( if (!clonedFunc) { SLANG_UNEXPECTED("expected entry point to be a function"); - return nullptr; + UNREACHABLE_RETURN(nullptr); } if (!clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration)) @@ -1022,7 +1182,8 @@ CapabilitySet getTargetCapabilities(IRSpecContext* context) return context->getShared()->targetReq->getTargetCaps(); } -/// Get the most appropriate ("best") capability requirements for `inVal` based on the `targetCaps`. +/// Get the most appropriate ("best") capability requirements for `inVal` based on the +/// `targetCaps`. static CapabilitySet _getBestSpecializationCaps(IRInst* inVal, CapabilitySet const& targetCaps) { IRInst* val = getResolvedInstForDecorations(inVal); @@ -1335,9 +1496,9 @@ IRInst* cloneGlobalValueWithLinkage( // with the same mangled name as `originalVal` and try // to pick the "best" one for our target. - auto mangledName = String(originalLinkage->getMangledName()); - RefPtr<IRSpecSymbol> sym; - if (!context->getSymbols().tryGetValue(mangledName, sym)) + auto mangledName = originalLinkage->getMangledName(); + IRSpecSymbol* sym = context->findSymbols(mangledName); + if (!sym) { if (!originalVal) return nullptr; @@ -1419,6 +1580,11 @@ void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv) { sharedContext->symbols.add(mangledName, sym); } + + if (as<IRImportDecoration>(linkage)) + sharedContext->isImportedSymbol.tryGetValueOrAdd(mangledName, true); + else + sharedContext->isImportedSymbol.set(mangledName, false); } void insertGlobalValueSymbols(IRSharedSpecContext* sharedContext, IRModule* originalModule) @@ -1605,8 +1771,8 @@ void convertAtomicToStorageBuffer( IRSpecContext* context, Dictionary<int, List<IRInst*>>& bindingToInstMapUnsorted) { - // Atomic_uint definitions needs to become a storage buffer to follow GL_EXT_vulkan_glsl_relaxed - // and to allow translation of atomic_uint into SPIRV + // Atomic_uint definitions needs to become a storage buffer to follow + // GL_EXT_vulkan_glsl_relaxed and to allow translation of atomic_uint into SPIRV IRBuilder builder = *context->builder; @@ -1653,8 +1819,8 @@ void convertAtomicToStorageBuffer( instToSwitch->setFullType(storageBuffer); // All references to a atomic_uint need to be an element ref. to emulate storage buffer - // usage All function calls must be inlined since storage buffers cannot pass as parameters - // to atomic methods + // usage All function calls must be inlined since storage buffers cannot pass as + // parameters to atomic methods for (auto& i : bindingToInstList.second) { int64_t currOffset = @@ -1722,12 +1888,13 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram, { case kIROp_GLSLAtomicUintType: { - // atomic_uint are supported by GLSL->VK through converting to a different type - // (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by SPIR-V->VK; - // this means that to get SPIR-V to work we must convert the type ourselves to - // an equivlent representation (storage buffer); the added benifit is that then - // HLSL is possible to emit as a target as well since atomic_uint is not an HLSL - // concept, but storageBuffer->RWBuffer is and HLSL concept + // atomic_uint are supported by GLSL->VK through converting to a different + // type (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by + // SPIR-V->VK; this means that to get SPIR-V to work we must convert the + // type ourselves to an equivlent representation (storage buffer); the added + // benifit is that then HLSL is possible to emit as a target as well since + // atomic_uint is not an HLSL concept, but storageBuffer->RWBuffer is and + // HLSL concept auto layout = inst->findDecoration<IRLayoutDecoration>()->getLayout(); auto layoutVal = as<IRVarOffsetAttr>(layout->getOperand(1)); assert(layoutVal != nullptr); @@ -1743,6 +1910,110 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram, convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted); } +bool isDiffPairType(IRInst* type) +{ + for (;;) + { + auto type1 = (IRType*)unwrapAttributedType(type); + auto type2 = unwrapArray(type1); + if (type2 == type) + break; + type = type2; + } + return as<IRDifferentialPairTypeBase>(type) != nullptr; +} + +bool doesModuleUseAutodiff(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Call: + if (auto callee = getResolvedInstForDecorations(inst->getOperand(0))) + { + switch (callee->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + return true; + } + } + return false; + case kIROp_DifferentialPairGetDifferentialUserCode: + case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetPrimal: + case kIROp_DifferentialPtrPairGetDifferential: + return true; + case kIROp_StructField: + return isDiffPairType(as<IRStructField>(inst)->getFieldType()); + case kIROp_Param: + return isDiffPairType(inst->getDataType()); + default: + for (auto child : inst->getChildren()) + { + bool isImported = false; + for (auto decor : child->getDecorations()) + { + if (as<IRImportDecoration>(decor)) + { + isImported = true; + break; + } + else if (as<IRAutoPyBindCudaDecoration>(decor)) + { + return true; + } + else if (as<IRAutoPyBindExportInfoDecoration>(decor)) + { + return true; + } + } + if (isImported) + continue; + for (auto decor : child->getDecorations()) + { + if (isAutoDiffDecoration(decor)) + return true; + } + if (doesModuleUseAutodiff(child)) + return true; + } + return false; + } +} + +void cloneUsedWitnessTableEntries(IRSpecContext* context) +{ + bool changed = true; + while (changed) + { + changed = false; + for (Index i = 0; i < context->witnessTables.getCount(); i++) + { + auto table = context->witnessTables[i].get(); + ShortList<UnownedStringSlice> entriesToRemove; + for (auto entry : table->deferredEntries) + { + if (context->deferredWitnessTableEntryKeys.contains(entry.first)) + { + IRBuilder builder(table->clonedTable); + builder.setInsertInto(table->clonedTable); + auto deferredKeyCount = context->deferredWitnessTableEntryKeys.getCount(); + cloneInst(context, &builder, entry.second); + entriesToRemove.add(entry.first); + if (deferredKeyCount != context->deferredWitnessTableEntryKeys.getCount()) + changed = true; + } + } + for (auto entry : entriesToRemove) + { + table->deferredEntries.remove(entry); + } + } + } +} + LinkedIR linkIR(CodeGenContext* codeGenContext) { SLANG_PROFILE; @@ -1763,50 +2034,61 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) state->target = target; state->targetReq = targetReq; - + auto& irModules = stateStorage.contextStorage.irModules; auto sharedContext = state->getSharedContext(); initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq); state->irModule = sharedContext->module; - // We need to be able to look up IR definitions for any symbols in // modules that the program depends on (transitively). To // accelerate lookup, we will create a symbol table for looking // up IR definitions by their mangled name. // - - List<IRModule*> irModules; - - // Link the core modules. - auto& coreModules = static_cast<Session*>(linkage->getGlobalSession())->coreModules; - for (auto& m : coreModules) - irModules.add(m->getIRModule()); + auto globalSession = static_cast<Session*>(linkage->getGlobalSession()); + List<IRModule*> builtinModules; + for (auto& m : globalSession->coreModules) + builtinModules.add(m->getIRModule()); // Link modules in the program. - program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); }); - - // Add any modules that were loaded as libraries - for (IRModule* irModule : irModules) - { - insertGlobalValueSymbols(sharedContext, irModule); - } + program->enumerateIRModules( + [&](IRModule* module) + { + if (module->getName() == globalSession->glslModuleName) + builtinModules.add(module); + else + irModules.add(module); + }); - // We will also insert the IR global symbols from the IR module + // We will also consider the IR global symbols from the IR module // attached to the `TargetProgram`, since this module is // responsible for associating layout information to those // global symbols via decorations. // auto irModuleForLayout = targetProgram->getExistingIRModuleForLayout(); - insertGlobalValueSymbols(sharedContext, irModuleForLayout); + if (irModuleForLayout) + irModules.add(irModuleForLayout); + + Index userModuleCount = irModules.getCount(); + irModules.addRange(builtinModules); + ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount); + + // Check if any user module uses auto-diff, if so we will need to link + // additional witnesses and decorations. + for (IRModule* irModule : userModules) + { + if (sharedContext->useAutodiff) + break; + sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst()); + } auto context = state->getContext(); // Combine all of the contents of IRGlobalHashedStringLiterals { StringSlicePool pool(StringSlicePool::Style::Empty); - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { findGlobalHashedStringLiterals(irModule, pool); } @@ -1875,7 +2157,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // instructions in all the input modules. // - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { for (auto inst : irModule->getGlobalInsts()) { @@ -1896,7 +2178,9 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // We need to copy over exported symbols, // and any global parameters if preserve-params option is set. if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) || - as<IRDifferentiableTypeAnnotation>(inst)) + sharedContext->useAutodiff && + (as<IRDifferentiableTypeAnnotation>(inst) || + inst->findDecorationImpl(kIROp_AutoDiffBuiltinDecoration) != nullptr)) { auto cloned = cloneValue(context, inst); if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration)) @@ -1907,6 +2191,12 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) } } + // In previous steps, we have skipped cloning the witness table entries, and + // registered any used witness table entry keys to context->deferredWitnessTableEntryKeys + // for on-demand cloning. Now we will use the deferred keys to clone the witness table + // entries that are referenced. + cloneUsedWitnessTableEntries(context); + // It is possible that metadata has been attached to the input modules // themselves, which should be copied over to the output module. // @@ -1917,7 +2207,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) // `[assumedWaveSize(...)]` decoration might require that all specified // values match exactly). // - for (IRModule* irModule : irModules) + for (IRModule* irModule : userModules) { for (auto decoration : irModule->getModuleInst()->getDecorations()) { diff --git a/source/slang/slang-ir-strip.cpp b/source/slang/slang-ir-strip.cpp index e9fbbceb6..6c64bfca2 100644 --- a/source/slang/slang-ir-strip.cpp +++ b/source/slang/slang-ir-strip.cpp @@ -51,4 +51,34 @@ void stripFrontEndOnlyInstructions(IRModule* module, IRStripOptions const& optio _stripFrontEndOnlyInstructionsRec(module->getModuleInst(), options); } +void stripImportedWitnessTable(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + auto inst = globalInst; + switch (globalInst->getOp()) + { + case kIROp_Generic: + inst = findInnerMostGenericReturnVal(as<IRGeneric>(globalInst)); + break; + case kIROp_WitnessTable: + break; + default: + continue; + } + if (inst->getOp() != kIROp_WitnessTable) + continue; + if (!globalInst->findDecoration<IRImportDecoration>()) + continue; + IRInst* nextChild = nullptr; + for (auto child = inst->getFirstChild(); child;) + { + nextChild = child->getNextInst(); + if (child->getOp() == kIROp_WitnessTable) + child->removeAndDeallocate(); + child = nextChild; + } + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-strip.h b/source/slang/slang-ir-strip.h index 9d97669a1..f6f53aeda 100644 --- a/source/slang/slang-ir-strip.h +++ b/source/slang/slang-ir-strip.h @@ -13,5 +13,9 @@ struct IRStripOptions /// Strip out instructions that should only be used by the front-end. void stripFrontEndOnlyInstructions(IRModule* module, IRStripOptions const& options); + +/// Strip witness table entries from imported witness tables. +void stripImportedWitnessTable(IRModule* module); + } // namespace Slang #pragma once diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index a3cf28a68..dbd6ac099 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2175,4 +2175,14 @@ void legalizeDefUse(IRGlobalValueWithCode* func) } } +UnownedStringSlice getMangledName(IRInst* inst) +{ + for (auto decor : inst->getDecorations()) + { + if (auto linkageDecor = as<IRLinkageDecoration>(decor)) + return linkageDecor->getMangledName(); + } + return UnownedStringSlice(); +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 46e0105f5..610524754 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -377,6 +377,8 @@ Int getSpecializationConstantId(IRGlobalParam* param); void legalizeDefUse(IRGlobalValueWithCode* func); +UnownedStringSlice getMangledName(IRInst* inst); + } // namespace Slang #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f28f61ffc..fb274c4a0 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4538,6 +4538,18 @@ RefPtr<IRModule> IRModule::create(Session* session) return module; } +void IRModule::buildMangledNameToGlobalInstMap() +{ + m_mapMangledNameToGlobalInst.clear(); + for (auto inst : getGlobalInsts()) + { + if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + { + m_mapMangledNameToGlobalInst[linkageDecor->getMangledName()].add(inst); + } + } +} + IRDominatorTree* IRModule::findOrCreateDominatorTree(IRGlobalValueWithCode* func) { IRAnalysis* analysis = m_mapInstToAnalysis.tryGetValue(func); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 8f53c9f14..ecf5d1c66 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2374,6 +2374,15 @@ public: m_obfuscatedSourceMap = sourceMap; } + ArrayView<IRInst*> findSymbolByMangledName(const ImmutableHashedString& mangledName) const + { + if (auto list = m_mapMangledNameToGlobalInst.tryGetValue(mangledName)) + return list->getArrayView(); + return {}; + } + + void buildMangledNameToGlobalInstMap(); + IRDeduplicationContext* getDeduplicationContext() const { return &m_deduplicationContext; } IRDominatorTree* findDominatorTree(IRGlobalValueWithCode* func) @@ -2392,6 +2401,9 @@ public: IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); } + Name* getName() const { return m_name; } + void setName(Name* name) { m_name = name; } + /// Create an empty instruction with the `op` opcode and space for /// a number of operands given by `operandCount`. /// @@ -2444,6 +2456,9 @@ private: /// IRModuleInst* m_moduleInst = nullptr; + // The name of the module. + Name* m_name = nullptr; + /// The memory arena from which all IR instructions (and any associated state) in this module /// are allocated. MemoryArena m_memoryArena; @@ -2459,6 +2474,8 @@ private: ComPtr<IBoxValue<SourceMap>> m_obfuscatedSourceMap; Dictionary<IRInst*, IRAnalysis> m_mapInstToAnalysis; + + Dictionary<ImmutableHashedString, List<IRInst*>> m_mapMangledNameToGlobalInst; }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index fbe6d8a84..e5037bf04 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8168,24 +8168,34 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // If the witness table is for a COM interface, always keep it alive. if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>()) { - subBuilder->addPublicDecoration(irWitnessTable); + subBuilder->addHLSLExportDecoration(irWitnessTable); } - if (parentDecl->findModifier<HLSLExportModifier>()) + for (auto mod : parentDecl->modifiers) { - subBuilder->addHLSLExportDecoration(irWitnessTable); - subBuilder->addKeepAliveDecoration(irWitnessTable); + if (as<HLSLExportModifier>(mod)) + { + subBuilder->addHLSLExportDecoration(irWitnessTable); + subBuilder->addKeepAliveDecoration(irWitnessTable); + } + else if (as<AutoDiffBuiltinAttribute>(mod)) + { + subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable); + } } // Make sure that all the entries in the witness table have been filled in, // including any cases where there are sub-witness-tables for conformances - Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable; - lowerWitnessTable( - subContext, - inheritanceDecl->witnessTable, - irWitnessTable, - mapASTToIRWitnessTable); - + bool isExplicitExtern = false; + if (!isImportedDecl(context, parentDecl, isExplicitExtern)) + { + Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable; + lowerWitnessTable( + subContext, + inheritanceDecl->witnessTable, + irWitnessTable, + mapASTToIRWitnessTable); + } irWitnessTable->moveToEnd(); return LoweredValInfo::simple( @@ -11536,6 +11546,8 @@ RefPtr<IRModule> generateIRForTranslationUnit( RefPtr<IRModule> module = IRModule::create(session); + module->setName(translationUnit->getModuleDecl()->getName()); + IRBuilder builderStorage(module); IRBuilder* builder = &builderStorage; @@ -11804,6 +11816,8 @@ RefPtr<IRModule> generateIRForTranslationUnit( stripOptions.stripSourceLocs = false; stripFrontEndOnlyInstructions(module, stripOptions); + stripImportedWitnessTable(module); + // Stripping out decorations could leave some dead code behind // in the module, and in some cases that extra code is also // undesirable (e.g., the string literals referenced by name-hint @@ -11847,6 +11861,8 @@ RefPtr<IRModule> generateIRForTranslationUnit( &writer); } + module->buildMangledNameToGlobalInstMap(); + return module; } @@ -11884,7 +11900,7 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor context->irBuilder = builder; componentType->acceptVisitor(this, nullptr); - + module->buildMangledNameToGlobalInstMap(); return module; } @@ -12040,6 +12056,7 @@ struct TypeConformanceIRGenContext { builder->addSequentialIDDecoration(witness, conformanceIdOverride); } + module->buildMangledNameToGlobalInstMap(); return module; } }; @@ -12507,7 +12524,7 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink) // Eliminate any dead code eliminateDeadCode(irModule, options); } - + irModule->buildMangledNameToGlobalInstMap(); m_irModuleForLayout = irModule; return irModule; } diff --git a/source/slang/slang-serialize-ir.cpp b/source/slang/slang-serialize-ir.cpp index c6926735a..b8a1183b3 100644 --- a/source/slang/slang-serialize-ir.cpp +++ b/source/slang/slang-serialize-ir.cpp @@ -1047,6 +1047,8 @@ Result IRSerialReader::read( } } + outModule->buildMangledNameToGlobalInstMap(); + return SLANG_OK; } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index c316974f1..ec90ee418 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -410,6 +410,11 @@ const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name) return result; } +TypeCheckingCache* Session::getTypeCheckingCache() +{ + return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()); +} + Session::BuiltinModuleInfo Session::getBuiltinModuleInfo(slang::BuiltinModuleName name) { Session::BuiltinModuleInfo result; @@ -700,6 +705,7 @@ SlangResult Session::_readBuiltinModule( module->setModuleDecl(moduleDecl); } + srcModule.irModule->setName(module->getNameObj()); module->setIRModule(srcModule.irModule); // Put in the loaded module map @@ -803,6 +809,10 @@ Session::createSession(slang::SessionDesc const& inDesc, slang::ISession** outSe RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage()); + if (m_typeCheckingCache) + linkage->m_typeCheckingCache = + new TypeCheckingCache(*static_cast<TypeCheckingCache*>(m_typeCheckingCache.get())); + linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode); Int searchPathCount = desc.searchPathCount; @@ -1263,9 +1273,6 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka , m_astBuilder(astBuilder) , m_cmdLineContext(new CommandLineContext()) { - if (builtinLinkage) - m_astBuilder->m_cachedNodes = builtinLinkage->getASTBuilder()->m_cachedNodes; - getNamePool()->setRootNamePool(session->getRootNamePool()); m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); @@ -1297,6 +1304,17 @@ ISlangUnknown* Linkage::getInterface(const Guid& guid) Linkage::~Linkage() { + // Upstream type checking cache. + if (m_typeCheckingCache) + { + auto globalSession = getSessionImpl(); + if (!globalSession->m_typeCheckingCache || + globalSession->getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount() < + getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount()) + { + globalSession->m_typeCheckingCache = m_typeCheckingCache; + } + } destroyTypeCheckingCache(); } @@ -1318,12 +1336,11 @@ TypeCheckingCache* Linkage::getTypeCheckingCache() { m_typeCheckingCache = new TypeCheckingCache(); } - return m_typeCheckingCache; + return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()); } void Linkage::destroyTypeCheckingCache() { - delete m_typeCheckingCache; m_typeCheckingCache = nullptr; } @@ -4080,6 +4097,8 @@ RefPtr<Module> Linkage::loadModuleFromIRBlobImpl( loadedModulesList.add(resultModule); resultModule->setPathInfo(filePathInfo); + resultModule->getIRModule()->setName(resultModule->getNameObj()); + return resultModule; } |
