summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-02-23 10:31:05 -0800
committerGitHub <noreply@github.com>2025-02-23 10:31:05 -0800
commit51ad07d1fbffd41c758eba172aa77ebba3204924 (patch)
treefadd788714c4ad37830846b0274d56b5ae1eff56 /source/slang
parent0101e5ab59a1678ed7212913c3880edfaf039537 (diff)
Improve performance when compiling small shaders. (#6396)
Improve performance when compiling small shaders. Avoid copying witness table entries that are not used during linking. Avoid copying auto-diff related decorations and derivative functions during linking if the user modules don't use autodiff. Cache operator overload resolution results on the global session, so each new Session doesn't need to repeatedly run through overload resolution from scratch.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/core.meta.slang10
-rw-r--r--source/slang/diff.meta.slang2
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-check-decl.cpp6
-rw-r--r--source/slang/slang-check-impl.h33
-rw-r--r--source/slang/slang-check-overload.cpp71
-rw-r--r--source/slang/slang-compiler.h5
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-emit.cpp26
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-autodiff.cpp8
-rw-r--r--source/slang/slang-ir-link.cpp390
-rw-r--r--source/slang/slang-ir-strip.cpp30
-rw-r--r--source/slang/slang-ir-strip.h4
-rw-r--r--source/slang/slang-ir-util.cpp10
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-ir.cpp12
-rw-r--r--source/slang/slang-ir.h17
-rw-r--r--source/slang/slang-lower-to-ir.cpp43
-rw-r--r--source/slang/slang-serialize-ir.cpp2
-rw-r--r--source/slang/slang.cpp29
21 files changed, 598 insertions, 108 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 267f7b2d4..da1b47e13 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -603,6 +603,7 @@ void static_assert(constexpr bool condition, NativeString errorMessage);
///
///
__magic_type(DifferentiableType)
+[KnownBuiltin("IDifferentiable")]
interface IDifferentiable
{
// Note: the compiler implementation requires the `Differential` associated type to be defined
@@ -645,6 +646,7 @@ interface IDifferentiable
/// @remarks Support for this interface is still experimental and subject to change.
///
__magic_type(DifferentiablePtrType)
+[KnownBuiltin("IDifferentiablePtr")]
interface IDifferentiablePtrType
{
__builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
@@ -2598,20 +2600,20 @@ for(auto fixity : kIncDecFixities)
$(fixity.qual)
__generic<T : __BuiltinArithmeticType>
[__unsafeForceInlineEarly]
-T operator$(op.name)(in out T value)
-{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
+T operator$(op.name)( in out T value)
+{ $(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
__generic<T : __BuiltinArithmeticType, let N : int>
[__unsafeForceInlineEarly]
vector<T,N> operator$(op.name)(in out vector<T,N> value)
-{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
+{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
__generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int>
[__unsafeForceInlineEarly]
matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
-{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
+{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
__generic<T, let addrSpace : uint64_t>
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 790dfaa79..d32584db0 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -176,7 +176,7 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
// for internal use.
//
[__AutoDiffBuiltin]
-export struct NullDifferential : IDifferentiable
+struct NullDifferential : IDifferentiable
{
// for now, we'll use at least one field to make sure the type is non-empty
uint dummy;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 173b97fd3..78eb48a33 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -20309,7 +20309,7 @@ void ReorderThread( HitObject HitOrMiss )
///
/// There doesn't appear to be an equivalent for debugBreak for HLSL
-
+[require(glsl)]
__specialized_for_target(glsl)
[[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]]
void __glslDebugBreak();
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5b5e05b73..0c42817c8 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10452,6 +10452,12 @@ void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDec
{
return;
}
+
+ if (getText(moduleDecl->getName()) == "glsl")
+ {
+ getShared()->glslModuleDecl = moduleDecl;
+ }
+
importedModulesList.add(moduleDecl);
importedModulesSet.add(moduleDecl);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 6438a91e3..edb199299 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -172,15 +172,17 @@ struct BasicTypeKeyPair
struct OperatorOverloadCacheKey
{
- intptr_t operatorName;
+ int32_t operatorName;
+ bool isGLSLMode;
BasicTypeKey args[2];
bool operator==(OperatorOverloadCacheKey key) const
{
- return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1];
+ return operatorName == key.operatorName && args[0] == key.args[0] &&
+ args[1] == key.args[1] && isGLSLMode == key.isGLSLMode;
}
HashCode getHashCode() const
{
- return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw());
+ return combineHash(operatorName, args[0].getRaw(), args[1].getRaw(), isGLSLMode ? 1 : 0);
}
bool fromOperatorExpr(OperatorExpr* opExpr)
{
@@ -299,10 +301,28 @@ struct OverloadCandidate
SubstitutionSet subst;
};
-struct TypeCheckingCache
+struct ResolvedOperatorOverload
{
- Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
+ // The resolved decl.
+ Decl* decl;
+
+ // The cached overload candidate in the current TypeCheckingCache.
+ // Note that a `OverloadCandidate` object is not migratable over different
+ // Linkages (compile sessions), so we will need to use `cacheVersion` to track
+ // if this `candidate` is valid for the current session. If not, we will
+ // recreate it from `decl`.
+ OverloadCandidate candidate;
+ // The version of the TypeCheckingCache for which the cached candidate is valid.
+ int cacheVersion;
+};
+
+struct TypeCheckingCache : public RefObject
+{
+ Dictionary<OperatorOverloadCacheKey, ResolvedOperatorOverload> resolvedOperatorOverloadCache;
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
+
+ // The version used to invalidate the cached declRefs in ResolvedOperatorOverload entries.
+ int version = 0;
};
enum class CoercionSite
@@ -635,6 +655,9 @@ struct SharedSemanticsContext : public RefObject
DiagnosticSink* m_sink = nullptr;
+ // The GLSL module's `ModuleDecl` if the current module has imported it; `nullptr` otherwise.
+ ModuleDecl* glslModuleDecl = nullptr;
+
/// (optional) modules that comes from previously processed translation units in the
/// front-end request that are made visible to the module being checked. This allows
/// `import` to use them instead of trying to find the files in file system.
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index b75f95f9a..f548fb819 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -2509,27 +2509,6 @@ String SemanticsVisitor::getCallSignatureString(OverloadResolveContext& context)
Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
{
OverloadResolveContext context;
- // check if this is a core module operator call, if so we want to use cached results
- // to speed up compilation
- bool shouldAddToCache = false;
- OperatorOverloadCacheKey key;
- TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
- if (auto opExpr = as<OperatorExpr>(expr))
- {
- if (key.fromOperatorExpr(opExpr))
- {
- OverloadCandidate candidate;
- if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
- {
- context.bestCandidateStorage = candidate;
- context.bestCandidate = &context.bestCandidateStorage;
- }
- else
- {
- shouldAddToCache = true;
- }
- }
- }
// Look at the base expression for the call, and figure out how to invoke it.
auto funcExpr = expr->functionExpr;
@@ -2569,6 +2548,43 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
context.loc = expr->loc;
context.sourceScope = m_outerScope;
context.baseExpr = GetBaseExpr(funcExpr);
+
+ // check if this is a core module operator call, if so we want to use cached results
+ // to speed up compilation
+ bool shouldAddToCache = false;
+ OperatorOverloadCacheKey key;
+ TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
+ if (auto opExpr = as<OperatorExpr>(expr))
+ {
+ if (key.fromOperatorExpr(opExpr))
+ {
+ key.isGLSLMode = getShared()->glslModuleDecl != nullptr;
+ ResolvedOperatorOverload candidate;
+ if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
+ {
+ // We should only use the cached candidate if it is persistent direct declref
+ // created from GlobalSession's ASTBuilder, or it is created in the current Linkage.
+ if (candidate.cacheVersion == typeCheckingCache->version ||
+ as<DirectDeclRef>(candidate.candidate.item.declRef.declRefBase))
+ {
+ context.bestCandidateStorage = candidate.candidate;
+ context.bestCandidate = &context.bestCandidateStorage;
+ }
+ else
+ {
+ LookupResultItem overloadCandidate = {};
+ overloadCandidate.declRef = getOuterGenericOrSelf(candidate.decl);
+ AddDeclRefOverloadCandidates(overloadCandidate, context, 0);
+ shouldAddToCache = true;
+ }
+ }
+ else
+ {
+ shouldAddToCache = true;
+ }
+ }
+ }
+
// We run a special case here where an `InvokeExpr`
// with a single argument where the base/func expression names
// a type should always be treated as an explicit type coercion
@@ -2731,7 +2747,18 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
// We will report errors for this one candidate, then, to give
// the user the most help we can.
if (shouldAddToCache)
- typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate;
+ {
+ if (isFromCoreModule(context.bestCandidate->item.declRef.getDecl()) ||
+ getShared()->glslModuleDecl ==
+ getModuleDecl(context.bestCandidate->item.declRef.getDecl()))
+ {
+ ResolvedOperatorOverload overloadResult;
+ overloadResult.candidate = *context.bestCandidate;
+ overloadResult.decl = context.bestCandidate->item.declRef.getDecl();
+ overloadResult.cacheVersion = typeCheckingCache->version;
+ typeCheckingCache->resolvedOperatorOverloadCache[key] = overloadResult;
+ }
+ }
// Now that we have resolved the overload candidate, we need to undo an `openExistential`
// operation that was applied to `out` arguments.
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 38725fff3..3e00e4b04 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2209,7 +2209,7 @@ public:
TypeCheckingCache* getTypeCheckingCache();
void destroyTypeCheckingCache();
- TypeCheckingCache* m_typeCheckingCache = nullptr;
+ RefPtr<RefObject> m_typeCheckingCache = nullptr;
// Modules that have been dynamically loaded via `import`
//
@@ -3589,6 +3589,9 @@ public:
int m_typeDictionarySize = 0;
+ RefPtr<RefObject> m_typeCheckingCache;
+ TypeCheckingCache* getTypeCheckingCache();
+
private:
struct BuiltinModuleInfo
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index cf653d001..03f52a701 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -49,6 +49,7 @@ DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'")
DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'")
DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'")
DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'")
+DIAGNOSTIC(-1, Note, seeCallOfFunc, "see call to '$0'")
DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition")
DIAGNOSTIC(
-1,
@@ -2309,7 +2310,7 @@ DIAGNOSTIC(
41402,
Error,
staticAssertionConditionNotConstant,
- "condition for static assertion cannot be evaluated at the compile-time.")
+ "condition for static assertion cannot be evaluated at compile time.")
DIAGNOSTIC(
41402,
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index e20a4a90f..847c5b55c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -475,6 +475,31 @@ void calcRequiredLoweringPassSet(
}
}
+void diagnoseCallStack(IRInst* inst, DiagnosticSink* sink)
+{
+ static const int maxDepth = 5;
+ for (int i = 0; i < maxDepth; i++)
+ {
+ auto func = getParentFunc(inst);
+ if (!func)
+ return;
+ bool shouldContinue = false;
+ for (auto use = func->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (auto call = as<IRCall>(user))
+ {
+ sink->diagnose(call, Diagnostics::seeCallOfFunc, func);
+ inst = call;
+ shouldContinue = true;
+ break;
+ }
+ }
+ if (!shouldContinue)
+ return;
+ }
+}
+
bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
{
switch (inst->getOp())
@@ -498,6 +523,7 @@ bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
{
sink->diagnose(inst, Diagnostics::staticAssertionFailureWithoutMessage);
}
+ diagnoseCallStack(inst, sink);
}
}
else
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 6fc7d56ad..d8500a694 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -2012,6 +2012,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeArrayFromElement:
case kIROp_MakeTuple:
case kIROp_MakeValuePack:
+ case kIROp_BuiltinCast:
return transcribeConstruct(builder, origInst);
case kIROp_MakeStruct:
return transcribeMakeStruct(builder, origInst);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 9075002e0..40dcb1b51 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1026,12 +1026,9 @@ IRInst* AutoDiffSharedContext::findDifferentiableInterface()
{
for (auto globalInst : module->getGlobalInsts())
{
- // TODO: This seems like a particularly dangerous way to look for an interface.
- // See if we can lower IDifferentiable to a separate IR inst.
- //
if (auto intf = as<IRInterfaceType>(globalInst))
{
- if (auto decor = intf->findDecoration<IRNameHintDecoration>())
+ if (auto decor = intf->findDecoration<IRKnownBuiltinDecoration>())
{
if (decor->getName() == toSlice("IDifferentiable"))
{
@@ -1261,7 +1258,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
}
addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness());
-
+#if 0
// TODO: Is this really needed?
if (!as<IRInterfaceType>(item->getBaseType()) &&
!as<IRAssociatedType>(item->getBaseType()))
@@ -1314,6 +1311,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
addTypeToDictionary((IRType*)diffType, diffWitness);
}
}
+#endif
}
}
}
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 53bbbda9e..125ab71d0 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -52,15 +52,28 @@ struct IRSharedSpecContext
// A map from mangled symbol names to zero or
// more global IR values that have that name,
// in the *original* module.
- typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary;
+ typedef Dictionary<ImmutableHashedString, RefPtr<IRSpecSymbol>> SymbolDictionary;
SymbolDictionary symbols;
+ Dictionary<ImmutableHashedString, bool> isImportedSymbol;
+
+ bool useAutodiff = false;
+
IRBuilder builderStorage;
// The "global" specialization environment.
IRSpecEnv globalEnv;
};
+void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv);
+
+struct WitnessTableCloneInfo : RefObject
+{
+ IRWitnessTable* clonedTable;
+ IRWitnessTable* originalTable;
+ Dictionary<UnownedStringSlice, IRWitnessTableEntry*> deferredEntries;
+};
+
struct IRSpecContextBase
{
IRSharedSpecContext* shared;
@@ -69,7 +82,27 @@ struct IRSpecContextBase
IRModule* getModule() { return getShared()->module; }
- IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; }
+ List<IRModule*> irModules;
+
+ HashSet<UnownedStringSlice> deferredWitnessTableEntryKeys;
+ List<RefPtr<WitnessTableCloneInfo>> witnessTables;
+
+ IRSpecSymbol* findSymbols(UnownedStringSlice mangledName)
+ {
+ ImmutableHashedString hashedName(mangledName);
+ RefPtr<IRSpecSymbol> symbol;
+ if (shared->symbols.tryGetValue(hashedName, symbol))
+ return symbol;
+ for (auto m : irModules)
+ {
+ for (auto inst : m->findSymbolByMangledName(hashedName))
+ insertGlobalValueSymbol(shared, inst);
+ }
+ if (shared->symbols.tryGetValue(hashedName, symbol))
+ return symbol;
+ shared->symbols[hashedName] = nullptr;
+ return nullptr;
+ }
// The current specialization environment to use.
IRSpecEnv* env = nullptr;
@@ -105,6 +138,26 @@ void registerClonedValue(IRSpecContextBase* context, IRInst* clonedValue, IRInst
// an `Add()` call.
//
context->getEnv()->clonedValues[originalValue] = clonedValue;
+
+ switch (clonedValue->getOp())
+ {
+ case kIROp_LookupWitness:
+
+ // If `clonedValue` is a witness lookup, record the mangled name of its
+ // requirement key in `deferredWitnessTableEntryKeys`.
+ context->deferredWitnessTableEntryKeys.add(
+ getMangledName(as<IRLookupWitnessMethod>(clonedValue)->getRequirementKey()));
+ break;
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ if (context->getShared()->useAutodiff)
+ {
+ if (auto key = as<IRStructKey>(clonedValue->getOperand(0)))
+ context->deferredWitnessTableEntryKeys.add(getMangledName(key));
+ }
+ break;
+ }
}
// Information on values to use when registering a cloned value
@@ -149,6 +202,28 @@ IRInst* cloneInst(IRSpecContextBase* context, IRBuilder* builder, IRInst* origin
return cloneInst(context, builder, originalInst, originalInst);
}
+bool isAutoDiffDecoration(IRInst* decor)
+{
+ switch (decor->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePrimalDecoration:
+ case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_BackwardDerivativePrimalContextDecoration:
+ case kIROp_BackwardDerivativePrimalReturnDecoration:
+ case kIROp_PrimalSubstituteDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_DifferentiableTypeDictionaryDecoration:
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ default:
+ return false;
+ }
+}
+
/// Clone any decorations from `originalValue` onto `clonedValue`
void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* originalValue)
{
@@ -165,6 +240,8 @@ void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* o
SLANG_UNUSED(context);
for (auto originalDecoration : originalValue->getDecorations())
{
+ if (!context->shared->useAutodiff && isAutoDiffDecoration(originalDecoration))
+ continue;
cloneInst(context, builder, originalDecoration);
}
@@ -185,6 +262,8 @@ void cloneDecorationsAndChildren(
SLANG_UNUSED(context);
for (auto originalItem : originalValue->getDecorationsAndChildren())
{
+ if (!context->shared->useAutodiff && isAutoDiffDecoration(originalItem))
+ continue;
cloneInst(context, builder, originalItem);
}
@@ -294,7 +373,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
cloneDecorationsAndChildren(this, clonedValue, originalValue);
addHoistableInst(builder, clonedValue);
-
return clonedValue;
}
break;
@@ -407,15 +485,17 @@ static void cloneExtraDecorationsFromInst(
{
default:
break;
-
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_PrimalSubstituteDecoration:
+ if (!context->getShared()->useAutodiff)
+ break;
+ [[fallthrough]];
case kIROp_HLSLExportDecoration:
case kIROp_BindExistentialSlotsDecoration:
case kIROp_LayoutDecoration:
case kIROp_PublicDecoration:
case kIROp_SequentialIDDecoration:
- case kIROp_ForwardDerivativeDecoration:
- case kIROp_UserDefinedBackwardDerivativeDecoration:
- case kIROp_PrimalSubstituteDecoration:
case kIROp_IntrinsicOpDecoration:
case kIROp_NonCopyableTypeDecoration:
case kIROp_DynamicDispatchWitnessDecoration:
@@ -602,6 +682,40 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl(
return clonedVal;
}
+bool shouldDeepCloneWitnessTable(IRSpecContextBase* context, IRWitnessTable* table)
+{
+ for (auto decor : table->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_HLSLExportDecoration:
+ case kIROp_KeepAliveDecoration:
+ return true;
+ }
+ }
+
+ auto conformanceType = getResolvedInstForDecorations(table->getConformanceType());
+
+ for (auto decor : conformanceType->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_ComInterfaceDecoration:
+ return true;
+ case kIROp_KnownBuiltinDecoration:
+ {
+ auto name = as<IRKnownBuiltinDecoration>(decor)->getName();
+ if (name == toSlice("IDifferentiable") || name == toSlice("IDifferentiablePtr"))
+ return context->getShared()->useAutodiff;
+ break;
+ }
+ default:
+ break;
+ }
+ }
+
+ return false;
+}
IRWitnessTable* cloneWitnessTableImpl(
IRSpecContextBase* context,
@@ -612,13 +726,56 @@ IRWitnessTable* cloneWitnessTableImpl(
bool registerValue = true)
{
IRWitnessTable* clonedTable = dstTable;
+ IRType* clonedBaseType = nullptr;
if (!clonedTable)
{
- auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
+ clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
}
- cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
+ else
+ {
+ clonedBaseType = (IRType*)clonedTable->getConformanceType();
+ }
+ if (registerValue)
+ registerClonedValue(context, clonedTable, originalValues);
+
+ // Set up an IR builder for inserting into the witness table
+ IRBuilder builderStorage = *context->builder;
+ IRBuilder* entryBuilder = &builderStorage;
+ entryBuilder->setInsertInto(clonedTable);
+
+ // Clone decorations first
+ for (auto decoration : originalTable->getDecorations())
+ {
+ cloneInst(context, entryBuilder, decoration);
+ }
+ cloneExtraDecorations(context, clonedTable, originalValues);
+
+ RefPtr<WitnessTableCloneInfo> witnessInfo = new WitnessTableCloneInfo();
+ witnessInfo->clonedTable = clonedTable;
+ witnessInfo->originalTable = originalTable;
+
+ bool shouldDeepClone = shouldDeepCloneWitnessTable(context, originalTable);
+
+ // Clone entries eagerly only when deep-cloning; otherwise defer them to be cloned on demand
+ for (auto child : originalTable->getDecorationsAndChildren())
+ {
+ if (auto entry = as<IRWitnessTableEntry>(child))
+ {
+ if (!shouldDeepClone)
+ {
+ // Skip witness table entries during the first pass,
+ // and just add them to the deferred work list.
+ witnessInfo->deferredEntries.add(getMangledName(entry->getRequirementKey()), entry);
+ continue;
+ }
+ }
+ // Clone any non-entry children as is
+ cloneInst(context, entryBuilder, child);
+ }
+ context->witnessTables.add(witnessInfo);
+
return clonedTable;
}
@@ -740,6 +897,9 @@ void cloneGlobalValueWithCodeCommon(
}
else
{
+ if (oi->getOp() == kIROp_DifferentiableTypeAnnotation &&
+ !context->getShared()->useAutodiff)
+ continue;
cloneInst(context, builder, oi);
}
}
@@ -886,12 +1046,12 @@ IRFunc* specializeIRForEntryPoint(
// so that the mangled name of the decl-ref is
// not the same as the mangled name of the decl.
//
- RefPtr<IRSpecSymbol> sym;
- if (!context->getSymbols().tryGetValue(mangledName, sym))
+ IRSpecSymbol* sym = context->findSymbols(mangledName.getUnownedSlice());
+ if (!sym)
{
String hashedName = getHashedName(mangledName.getUnownedSlice());
-
- if (!context->getSymbols().tryGetValue(hashedName, sym))
+ sym = context->findSymbols(hashedName.getUnownedSlice());
+ if (!sym)
{
SLANG_UNEXPECTED("no matching IR symbol");
return nullptr;
@@ -967,7 +1127,7 @@ IRFunc* specializeIRForEntryPoint(
if (!clonedFunc)
{
SLANG_UNEXPECTED("expected entry point to be a function");
- return nullptr;
+ UNREACHABLE_RETURN(nullptr);
}
if (!clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration))
@@ -1022,7 +1182,8 @@ CapabilitySet getTargetCapabilities(IRSpecContext* context)
return context->getShared()->targetReq->getTargetCaps();
}
-/// Get the most appropriate ("best") capability requirements for `inVal` based on the `targetCaps`.
+/// Get the most appropriate ("best") capability requirements for `inVal` based on the
+/// `targetCaps`.
static CapabilitySet _getBestSpecializationCaps(IRInst* inVal, CapabilitySet const& targetCaps)
{
IRInst* val = getResolvedInstForDecorations(inVal);
@@ -1335,9 +1496,9 @@ IRInst* cloneGlobalValueWithLinkage(
// with the same mangled name as `originalVal` and try
// to pick the "best" one for our target.
- auto mangledName = String(originalLinkage->getMangledName());
- RefPtr<IRSpecSymbol> sym;
- if (!context->getSymbols().tryGetValue(mangledName, sym))
+ auto mangledName = originalLinkage->getMangledName();
+ IRSpecSymbol* sym = context->findSymbols(mangledName);
+ if (!sym)
{
if (!originalVal)
return nullptr;
@@ -1419,6 +1580,11 @@ void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv)
{
sharedContext->symbols.add(mangledName, sym);
}
+
+ if (as<IRImportDecoration>(linkage))
+ sharedContext->isImportedSymbol.tryGetValueOrAdd(mangledName, true);
+ else
+ sharedContext->isImportedSymbol.set(mangledName, false);
}
void insertGlobalValueSymbols(IRSharedSpecContext* sharedContext, IRModule* originalModule)
@@ -1605,8 +1771,8 @@ void convertAtomicToStorageBuffer(
IRSpecContext* context,
Dictionary<int, List<IRInst*>>& bindingToInstMapUnsorted)
{
- // Atomic_uint definitions needs to become a storage buffer to follow GL_EXT_vulkan_glsl_relaxed
- // and to allow translation of atomic_uint into SPIRV
+ // Atomic_uint definitions need to become a storage buffer to follow
+ // GL_EXT_vulkan_glsl_relaxed and to allow translation of atomic_uint into SPIRV
IRBuilder builder = *context->builder;
@@ -1653,8 +1819,8 @@ void convertAtomicToStorageBuffer(
instToSwitch->setFullType(storageBuffer);
// All references to a atomic_uint need to be an element ref. to emulate storage buffer
- // usage All function calls must be inlined since storage buffers cannot pass as parameters
- // to atomic methods
+ // usage. All function calls must be inlined since storage buffers cannot pass as
+ // parameters to atomic methods
for (auto& i : bindingToInstList.second)
{
int64_t currOffset =
@@ -1722,12 +1888,13 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
{
case kIROp_GLSLAtomicUintType:
{
- // atomic_uint are supported by GLSL->VK through converting to a different type
- // (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by SPIR-V->VK;
- // this means that to get SPIR-V to work we must convert the type ourselves to
- // an equivlent representation (storage buffer); the added benifit is that then
- // HLSL is possible to emit as a target as well since atomic_uint is not an HLSL
- // concept, but storageBuffer->RWBuffer is and HLSL concept
+ // atomic_uint are supported by GLSL->VK through converting to a different
+ // type (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by
+ // SPIR-V->VK; this means that to get SPIR-V to work we must convert the
+ // type ourselves to an equivalent representation (storage buffer); the added
+ // benefit is that then HLSL is possible to emit as a target as well since
+ // atomic_uint is not an HLSL concept, but storageBuffer->RWBuffer is an
+ // HLSL concept
auto layout = inst->findDecoration<IRLayoutDecoration>()->getLayout();
auto layoutVal = as<IRVarOffsetAttr>(layout->getOperand(1));
assert(layoutVal != nullptr);
@@ -1743,6 +1910,110 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted);
}
+bool isDiffPairType(IRInst* type)
+{
+ for (;;)
+ {
+ auto type1 = (IRType*)unwrapAttributedType(type);
+ auto type2 = unwrapArray(type1);
+ if (type2 == type)
+ break;
+ type = type2;
+ }
+ return as<IRDifferentialPairTypeBase>(type) != nullptr;
+}
+
+bool doesModuleUseAutodiff(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ if (auto callee = getResolvedInstForDecorations(inst->getOperand(0)))
+ {
+ switch (callee->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ return true;
+ }
+ }
+ return false;
+ case kIROp_DifferentialPairGetDifferentialUserCode:
+ case kIROp_DifferentialPairGetPrimalUserCode:
+ case kIROp_DifferentialPtrPairGetPrimal:
+ case kIROp_DifferentialPtrPairGetDifferential:
+ return true;
+ case kIROp_StructField:
+ return isDiffPairType(as<IRStructField>(inst)->getFieldType());
+ case kIROp_Param:
+ return isDiffPairType(inst->getDataType());
+ default:
+ for (auto child : inst->getChildren())
+ {
+ bool isImported = false;
+ for (auto decor : child->getDecorations())
+ {
+ if (as<IRImportDecoration>(decor))
+ {
+ isImported = true;
+ break;
+ }
+ else if (as<IRAutoPyBindCudaDecoration>(decor))
+ {
+ return true;
+ }
+ else if (as<IRAutoPyBindExportInfoDecoration>(decor))
+ {
+ return true;
+ }
+ }
+ if (isImported)
+ continue;
+ for (auto decor : child->getDecorations())
+ {
+ if (isAutoDiffDecoration(decor))
+ return true;
+ }
+ if (doesModuleUseAutodiff(child))
+ return true;
+ }
+ return false;
+ }
+}
+
+void cloneUsedWitnessTableEntries(IRSpecContext* context)
+{
+ bool changed = true;
+ while (changed)
+ {
+ changed = false;
+ for (Index i = 0; i < context->witnessTables.getCount(); i++)
+ {
+ auto table = context->witnessTables[i].get();
+ ShortList<UnownedStringSlice> entriesToRemove;
+ for (auto entry : table->deferredEntries)
+ {
+ if (context->deferredWitnessTableEntryKeys.contains(entry.first))
+ {
+ IRBuilder builder(table->clonedTable);
+ builder.setInsertInto(table->clonedTable);
+ auto deferredKeyCount = context->deferredWitnessTableEntryKeys.getCount();
+ cloneInst(context, &builder, entry.second);
+ entriesToRemove.add(entry.first);
+ if (deferredKeyCount != context->deferredWitnessTableEntryKeys.getCount())
+ changed = true;
+ }
+ }
+ for (auto entry : entriesToRemove)
+ {
+ table->deferredEntries.remove(entry);
+ }
+ }
+ }
+}
+
LinkedIR linkIR(CodeGenContext* codeGenContext)
{
SLANG_PROFILE;
@@ -1763,50 +2034,61 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
state->target = target;
state->targetReq = targetReq;
-
+ auto& irModules = stateStorage.contextStorage.irModules;
auto sharedContext = state->getSharedContext();
initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq);
state->irModule = sharedContext->module;
-
// We need to be able to look up IR definitions for any symbols in
// modules that the program depends on (transitively). To
// accelerate lookup, we will create a symbol table for looking
// up IR definitions by their mangled name.
//
-
- List<IRModule*> irModules;
-
- // Link the core modules.
- auto& coreModules = static_cast<Session*>(linkage->getGlobalSession())->coreModules;
- for (auto& m : coreModules)
- irModules.add(m->getIRModule());
+ auto globalSession = static_cast<Session*>(linkage->getGlobalSession());
+ List<IRModule*> builtinModules;
+ for (auto& m : globalSession->coreModules)
+ builtinModules.add(m->getIRModule());
// Link modules in the program.
- program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); });
-
- // Add any modules that were loaded as libraries
- for (IRModule* irModule : irModules)
- {
- insertGlobalValueSymbols(sharedContext, irModule);
- }
+ program->enumerateIRModules(
+ [&](IRModule* module)
+ {
+ if (module->getName() == globalSession->glslModuleName)
+ builtinModules.add(module);
+ else
+ irModules.add(module);
+ });
- // We will also insert the IR global symbols from the IR module
+ // We will also consider the IR global symbols from the IR module
// attached to the `TargetProgram`, since this module is
// responsible for associating layout information to those
// global symbols via decorations.
//
auto irModuleForLayout = targetProgram->getExistingIRModuleForLayout();
- insertGlobalValueSymbols(sharedContext, irModuleForLayout);
+ if (irModuleForLayout)
+ irModules.add(irModuleForLayout);
+
+ Index userModuleCount = irModules.getCount();
+ irModules.addRange(builtinModules);
+ ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount);
+
+ // Check whether any user module uses auto-diff; if so, we will need to link
+ // additional witnesses and decorations.
+ for (IRModule* irModule : userModules)
+ {
+ if (sharedContext->useAutodiff)
+ break;
+ sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst());
+ }
auto context = state->getContext();
// Combine all of the contents of IRGlobalHashedStringLiterals
{
StringSlicePool pool(StringSlicePool::Style::Empty);
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
findGlobalHashedStringLiterals(irModule, pool);
}
@@ -1875,7 +2157,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// instructions in all the input modules.
//
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
for (auto inst : irModule->getGlobalInsts())
{
@@ -1896,7 +2178,9 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// We need to copy over exported symbols,
// and any global parameters if preserve-params option is set.
if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) ||
- as<IRDifferentiableTypeAnnotation>(inst))
+ sharedContext->useAutodiff &&
+ (as<IRDifferentiableTypeAnnotation>(inst) ||
+ inst->findDecorationImpl(kIROp_AutoDiffBuiltinDecoration) != nullptr))
{
auto cloned = cloneValue(context, inst);
if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration))
@@ -1907,6 +2191,12 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
}
}
+ // In previous steps, we have skipped cloning the witness table entries, and
+ // registered any used witness table entry keys to context->deferredWitnessTableEntryKeys
+ // for on-demand cloning. Now we will use the deferred keys to clone the witness table
+ // entries that are referenced.
+ cloneUsedWitnessTableEntries(context);
+
// It is possible that metadata has been attached to the input modules
// themselves, which should be copied over to the output module.
//
@@ -1917,7 +2207,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// `[assumedWaveSize(...)]` decoration might require that all specified
// values match exactly).
//
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
for (auto decoration : irModule->getModuleInst()->getDecorations())
{
diff --git a/source/slang/slang-ir-strip.cpp b/source/slang/slang-ir-strip.cpp
index e9fbbceb6..6c64bfca2 100644
--- a/source/slang/slang-ir-strip.cpp
+++ b/source/slang/slang-ir-strip.cpp
@@ -51,4 +51,34 @@ void stripFrontEndOnlyInstructions(IRModule* module, IRStripOptions const& optio
_stripFrontEndOnlyInstructionsRec(module->getModuleInst(), options);
}
+/// Strip witness table entries from the imported witness tables of `module`.
+///
+/// Every global instruction of `module` is visited. A global is processed when it
+/// is a witness table, or a generic whose innermost return value is a witness
+/// table, and the global itself carries an `IRImportDecoration`. The witness table
+/// entry children of such a table are removed and deallocated; all other children
+/// are left untouched.
+void stripImportedWitnessTable(IRModule* module)
+{
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ auto inst = globalInst;
+ switch (globalInst->getOp())
+ {
+ case kIROp_Generic:
+ // For a generic, the table of interest is what the generic returns.
+ inst = findInnerMostGenericReturnVal(as<IRGeneric>(globalInst));
+ break;
+ case kIROp_WitnessTable:
+ break;
+ default:
+ continue;
+ }
+ // Guard against a generic without a resolvable return value before
+ // inspecting the opcode.
+ if (!inst || inst->getOp() != kIROp_WitnessTable)
+ continue;
+ // The import decoration lives on the global (the generic, if any), not on
+ // the inner table.
+ if (!globalInst->findDecoration<IRImportDecoration>())
+ continue;
+ for (auto child = inst->getFirstChild(); child;)
+ {
+ // Fetch the successor first: `child` may be removed below.
+ auto nextChild = child->getNextInst();
+ // The children to strip are the table's entries. Comparing against
+ // kIROp_WitnessTable here would never match, making this pass a no-op.
+ if (child->getOp() == kIROp_WitnessTableEntry)
+ child->removeAndDeallocate();
+ child = nextChild;
+ }
+ }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-strip.h b/source/slang/slang-ir-strip.h
index 9d97669a1..f6f53aeda 100644
--- a/source/slang/slang-ir-strip.h
+++ b/source/slang/slang-ir-strip.h
@@ -13,5 +13,9 @@ struct IRStripOptions
/// Strip out instructions that should only be used by the front-end.
void stripFrontEndOnlyInstructions(IRModule* module, IRStripOptions const& options);
+
+/// Strip witness table entries from imported witness tables.
+void stripImportedWitnessTable(IRModule* module);
+
} // namespace Slang
#pragma once
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index a3cf28a68..dbd6ac099 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2175,4 +2175,14 @@ void legalizeDefUse(IRGlobalValueWithCode* func)
}
}
+/// Return the mangled name carried by the first `IRLinkageDecoration` found on
+/// `inst`, or a default-constructed slice when `inst` has no such decoration.
+UnownedStringSlice getMangledName(IRInst* inst)
+{
+ for (auto decoration : inst->getDecorations())
+ {
+ auto linkage = as<IRLinkageDecoration>(decoration);
+ if (!linkage)
+ continue;
+ return linkage->getMangledName();
+ }
+ return UnownedStringSlice();
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 46e0105f5..610524754 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -377,6 +377,8 @@ Int getSpecializationConstantId(IRGlobalParam* param);
void legalizeDefUse(IRGlobalValueWithCode* func);
+UnownedStringSlice getMangledName(IRInst* inst);
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f28f61ffc..fb274c4a0 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4538,6 +4538,18 @@ RefPtr<IRModule> IRModule::create(Session* session)
return module;
}
+/// Rebuild the lookup table from mangled name to global instructions.
+///
+/// Any previous contents are discarded. Every global instruction that has an
+/// `IRLinkageDecoration` is appended to the list stored under its mangled name;
+/// globals without one are not indexed.
+void IRModule::buildMangledNameToGlobalInstMap()
+{
+ m_mapMangledNameToGlobalInst.clear();
+ for (auto globalInst : getGlobalInsts())
+ {
+ auto linkage = globalInst->findDecoration<IRLinkageDecoration>();
+ if (!linkage)
+ continue;
+ auto mangledName = linkage->getMangledName();
+ m_mapMangledNameToGlobalInst[mangledName].add(globalInst);
+ }
+}
+
IRDominatorTree* IRModule::findOrCreateDominatorTree(IRGlobalValueWithCode* func)
{
IRAnalysis* analysis = m_mapInstToAnalysis.tryGetValue(func);
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 8f53c9f14..ecf5d1c66 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -2374,6 +2374,15 @@ public:
m_obfuscatedSourceMap = sourceMap;
}
+ /// Return the global instructions recorded under `mangledName` in the
+ /// mangled-name map, or an empty view when the name has no entry.
+ ArrayView<IRInst*> findSymbolByMangledName(const ImmutableHashedString& mangledName) const
+ {
+ auto matches = m_mapMangledNameToGlobalInst.tryGetValue(mangledName);
+ if (!matches)
+ return {};
+ return matches->getArrayView();
+ }
+
+ void buildMangledNameToGlobalInstMap();
+
IRDeduplicationContext* getDeduplicationContext() const { return &m_deduplicationContext; }
IRDominatorTree* findDominatorTree(IRGlobalValueWithCode* func)
@@ -2392,6 +2401,9 @@ public:
IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); }
+ Name* getName() const { return m_name; }
+ void setName(Name* name) { m_name = name; }
+
/// Create an empty instruction with the `op` opcode and space for
/// a number of operands given by `operandCount`.
///
@@ -2444,6 +2456,9 @@ private:
///
IRModuleInst* m_moduleInst = nullptr;
+ // The name of the module.
+ Name* m_name = nullptr;
+
/// The memory arena from which all IR instructions (and any associated state) in this module
/// are allocated.
MemoryArena m_memoryArena;
@@ -2459,6 +2474,8 @@ private:
ComPtr<IBoxValue<SourceMap>> m_obfuscatedSourceMap;
Dictionary<IRInst*, IRAnalysis> m_mapInstToAnalysis;
+
+ Dictionary<ImmutableHashedString, List<IRInst*>> m_mapMangledNameToGlobalInst;
};
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index fbe6d8a84..e5037bf04 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8168,24 +8168,34 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// If the witness table is for a COM interface, always keep it alive.
if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
{
- subBuilder->addPublicDecoration(irWitnessTable);
+ subBuilder->addHLSLExportDecoration(irWitnessTable);
}
- if (parentDecl->findModifier<HLSLExportModifier>())
+ for (auto mod : parentDecl->modifiers)
{
- subBuilder->addHLSLExportDecoration(irWitnessTable);
- subBuilder->addKeepAliveDecoration(irWitnessTable);
+ if (as<HLSLExportModifier>(mod))
+ {
+ subBuilder->addHLSLExportDecoration(irWitnessTable);
+ subBuilder->addKeepAliveDecoration(irWitnessTable);
+ }
+ else if (as<AutoDiffBuiltinAttribute>(mod))
+ {
+ subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
+ }
}
// Make sure that all the entries in the witness table have been filled in,
// including any cases where there are sub-witness-tables for conformances
- Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
- lowerWitnessTable(
- subContext,
- inheritanceDecl->witnessTable,
- irWitnessTable,
- mapASTToIRWitnessTable);
-
+ bool isExplicitExtern = false;
+ if (!isImportedDecl(context, parentDecl, isExplicitExtern))
+ {
+ Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
+ lowerWitnessTable(
+ subContext,
+ inheritanceDecl->witnessTable,
+ irWitnessTable,
+ mapASTToIRWitnessTable);
+ }
irWitnessTable->moveToEnd();
return LoweredValInfo::simple(
@@ -11536,6 +11546,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
RefPtr<IRModule> module = IRModule::create(session);
+ module->setName(translationUnit->getModuleDecl()->getName());
+
IRBuilder builderStorage(module);
IRBuilder* builder = &builderStorage;
@@ -11804,6 +11816,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
stripOptions.stripSourceLocs = false;
stripFrontEndOnlyInstructions(module, stripOptions);
+ stripImportedWitnessTable(module);
+
// Stripping out decorations could leave some dead code behind
// in the module, and in some cases that extra code is also
// undesirable (e.g., the string literals referenced by name-hint
@@ -11847,6 +11861,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
&writer);
}
+ module->buildMangledNameToGlobalInstMap();
+
return module;
}
@@ -11884,7 +11900,7 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor
context->irBuilder = builder;
componentType->acceptVisitor(this, nullptr);
-
+ module->buildMangledNameToGlobalInstMap();
return module;
}
@@ -12040,6 +12056,7 @@ struct TypeConformanceIRGenContext
{
builder->addSequentialIDDecoration(witness, conformanceIdOverride);
}
+ module->buildMangledNameToGlobalInstMap();
return module;
}
};
@@ -12507,7 +12524,7 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink)
// Eliminate any dead code
eliminateDeadCode(irModule, options);
}
-
+ irModule->buildMangledNameToGlobalInstMap();
m_irModuleForLayout = irModule;
return irModule;
}
diff --git a/source/slang/slang-serialize-ir.cpp b/source/slang/slang-serialize-ir.cpp
index c6926735a..b8a1183b3 100644
--- a/source/slang/slang-serialize-ir.cpp
+++ b/source/slang/slang-serialize-ir.cpp
@@ -1047,6 +1047,8 @@ Result IRSerialReader::read(
}
}
+ outModule->buildMangledNameToGlobalInstMap();
+
return SLANG_OK;
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index c316974f1..ec90ee418 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -410,6 +410,11 @@ const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name)
return result;
}
+/// Return the pointer held in `m_typeCheckingCache`, viewed as a
+/// `TypeCheckingCache`.
+TypeCheckingCache* Session::getTypeCheckingCache()
+{
+ auto* storedCache = m_typeCheckingCache.get();
+ return static_cast<TypeCheckingCache*>(storedCache);
+}
+
Session::BuiltinModuleInfo Session::getBuiltinModuleInfo(slang::BuiltinModuleName name)
{
Session::BuiltinModuleInfo result;
@@ -700,6 +705,7 @@ SlangResult Session::_readBuiltinModule(
module->setModuleDecl(moduleDecl);
}
+ srcModule.irModule->setName(module->getNameObj());
module->setIRModule(srcModule.irModule);
// Put in the loaded module map
@@ -803,6 +809,10 @@ Session::createSession(slang::SessionDesc const& inDesc, slang::ISession** outSe
RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage());
+ if (m_typeCheckingCache)
+ linkage->m_typeCheckingCache =
+ new TypeCheckingCache(*static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()));
+
linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode);
Int searchPathCount = desc.searchPathCount;
@@ -1263,9 +1273,6 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka
, m_astBuilder(astBuilder)
, m_cmdLineContext(new CommandLineContext())
{
- if (builtinLinkage)
- m_astBuilder->m_cachedNodes = builtinLinkage->getASTBuilder()->m_cachedNodes;
-
getNamePool()->setRootNamePool(session->getRootNamePool());
m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr);
@@ -1297,6 +1304,17 @@ ISlangUnknown* Linkage::getInterface(const Guid& guid)
Linkage::~Linkage()
{
+ // Hand this linkage's type checking cache up to the global session so it can
+ // be reused later. The session's cache is replaced only when the session has
+ // none yet, or when its resolved-operator-overload cache holds fewer entries
+ // than this linkage's.
+ if (m_typeCheckingCache)
+ {
+ auto globalSession = getSessionImpl();
+ if (!globalSession->m_typeCheckingCache ||
+ globalSession->getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount() <
+ getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount())
+ {
+ globalSession->m_typeCheckingCache = m_typeCheckingCache;
+ }
+ }
destroyTypeCheckingCache();
}
@@ -1318,12 +1336,11 @@ TypeCheckingCache* Linkage::getTypeCheckingCache()
{
m_typeCheckingCache = new TypeCheckingCache();
}
- return m_typeCheckingCache;
+ return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get());
}
void Linkage::destroyTypeCheckingCache()
{
- delete m_typeCheckingCache;
m_typeCheckingCache = nullptr;
}
@@ -4080,6 +4097,8 @@ RefPtr<Module> Linkage::loadModuleFromIRBlobImpl(
loadedModulesList.add(resultModule);
resultModule->setPathInfo(filePathInfo);
+ resultModule->getIRModule()->setName(resultModule->getNameObj());
+
return resultModule;
}