summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/design/stdlib-intrinsics.md6
-rw-r--r--source/core/slang-string.h51
-rw-r--r--source/slang/core.meta.slang10
-rw-r--r--source/slang/diff.meta.slang2
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-check-decl.cpp6
-rw-r--r--source/slang/slang-check-impl.h33
-rw-r--r--source/slang/slang-check-overload.cpp71
-rw-r--r--source/slang/slang-compiler.h5
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-emit.cpp26
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-autodiff.cpp8
-rw-r--r--source/slang/slang-ir-link.cpp390
-rw-r--r--source/slang/slang-ir-strip.cpp30
-rw-r--r--source/slang/slang-ir-strip.h4
-rw-r--r--source/slang/slang-ir-util.cpp10
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-ir.cpp12
-rw-r--r--source/slang/slang-ir.h17
-rw-r--r--source/slang/slang-lower-to-ir.cpp43
-rw-r--r--source/slang/slang-serialize-ir.cpp2
-rw-r--r--source/slang/slang.cpp29
-rw-r--r--tests/language-feature/static_assert.slang12
-rw-r--r--tools/slang-unit-test/unit-test-compile-benchmark.cpp113
25 files changed, 766 insertions, 122 deletions
diff --git a/docs/design/stdlib-intrinsics.md b/docs/design/stdlib-intrinsics.md
index a9369138d..2ea50cd54 100644
--- a/docs/design/stdlib-intrinsics.md
+++ b/docs/design/stdlib-intrinsics.md
@@ -114,12 +114,6 @@ Sections of the `expansion` string that are to be replaced are prefixed by the `
* $XH - Ray tracing hit object attribute
* $P - Type-based prefix as used for CUDA and C++ targets (I8 for int8_t, F32 - float etc)
-## __specialized_for_target(target)
-
-Specialized for target allows defining an implementation *body* for a particular target. The target is the same as is used for [__target_intrinsic](#target-intrinsic).
-
-A declaration can consist of multiple definitions with bodies (for each target) using, `specialized_for_target`, as well as having `target_intrinsic` if that is applicable for a target.
-
## __attributeTarget(astClassName)
For an attribute, specifies the AST class (and derived class) the attribute can be applied to.
diff --git a/source/core/slang-string.h b/source/core/slang-string.h
index 24b119383..3da0db6b9 100644
--- a/source/core/slang-string.h
+++ b/source/core/slang-string.h
@@ -790,6 +790,57 @@ public:
UnownedStringSlice getUnownedSlice() const { return StringRepresentation::asSlice(m_buffer); }
};
+/// A string bundled with its hash code, which is computed once at construction.
+///
+/// `getHashCode()` returns the cached value instead of re-hashing the text, and
+/// equality against another `ImmutableHashedString` compares the cached hashes
+/// before comparing the text itself.
+///
+/// NOTE(review): both members are public; presumably callers are expected not to
+/// modify them after construction, since that would leave `hashCode` stale.
+class ImmutableHashedString
+{
+public:
+    // The text. Declared before `hashCode` on purpose: members are initialized in
+    // declaration order, so a constructor may hash `slice` once it has been built.
+    String slice;
+    // Hash code of `slice`, captured at construction.
+    HashCode64 hashCode;
+
+    // Default state: default-constructed `slice` with a hash code of 0.
+    ImmutableHashedString()
+        : hashCode(0)
+    {
+    }
+    // Note: inside this initializer list `slice.getHashCode()` refers to the
+    // parameter (which shadows the member), i.e. the hash of the incoming slice.
+    ImmutableHashedString(const UnownedStringSlice& slice)
+        : slice(slice), hashCode(slice.getHashCode())
+    {
+    }
+    ImmutableHashedString(const char* begin, const char* end)
+        : slice(begin, end), hashCode(slice.getHashCode())
+    {
+    }
+    ImmutableHashedString(const char* begin, size_t len)
+        : slice(UnownedStringSlice(begin, len)), hashCode(slice.getHashCode())
+    {
+    }
+    ImmutableHashedString(const char* begin)
+        : slice(begin), hashCode(slice.getHashCode())
+    {
+    }
+    ImmutableHashedString(const String& str)
+        : slice(str), hashCode(str.getHashCode())
+    {
+    }
+    // `str` has already been moved into `slice` by the time `hashCode` is
+    // initialized, so the hash must be taken from the member, never from the
+    // moved-from argument.
+    ImmutableHashedString(String&& str)
+        : slice(_Move(str)), hashCode(slice.getHashCode())
+    {
+    }
+    ImmutableHashedString(const ImmutableHashedString& other) = default;
+    ImmutableHashedString& operator=(const ImmutableHashedString& other) = default;
+
+    // Comparisons against another hashed string check the cached hash first.
+    bool operator==(const ImmutableHashedString& other) const
+    {
+        return hashCode == other.hashCode && slice == other.slice;
+    }
+    bool operator!=(const ImmutableHashedString& other) const
+    {
+        return hashCode != other.hashCode || slice != other.slice;
+    }
+
+    // Comparisons against plain strings compare the text only.
+    bool operator==(const UnownedStringSlice& other) const { return slice == other; }
+    bool operator!=(const UnownedStringSlice& other) const { return slice != other; }
+    bool operator==(const String& other) const { return slice == other.getUnownedSlice(); }
+    bool operator!=(const String& other) const { return slice != other.getUnownedSlice(); }
+    bool operator==(const char* other) const { return slice == UnownedStringSlice(other); }
+    bool operator!=(const char* other) const { return slice != UnownedStringSlice(other); }
+
+    // Returns the hash code cached at construction.
+    HashCode64 getHashCode() const { return hashCode; }
+};
+
class SLANG_RT_API StringBuilder : public String
{
private:
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 267f7b2d4..da1b47e13 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -603,6 +603,7 @@ void static_assert(constexpr bool condition, NativeString errorMessage);
///
///
__magic_type(DifferentiableType)
+[KnownBuiltin("IDifferentiable")]
interface IDifferentiable
{
// Note: the compiler implementation requires the `Differential` associated type to be defined
@@ -645,6 +646,7 @@ interface IDifferentiable
/// @remarks Support for this interface is still experimental and subject to change.
///
__magic_type(DifferentiablePtrType)
+[KnownBuiltin("IDifferentiablePtr")]
interface IDifferentiablePtrType
{
__builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
@@ -2598,20 +2600,20 @@ for(auto fixity : kIncDecFixities)
$(fixity.qual)
__generic<T : __BuiltinArithmeticType>
[__unsafeForceInlineEarly]
-T operator$(op.name)(in out T value)
-{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
+T operator$(op.name)( in out T value)
+{ $(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
__generic<T : __BuiltinArithmeticType, let N : int>
[__unsafeForceInlineEarly]
vector<T,N> operator$(op.name)(in out vector<T,N> value)
-{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
+{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
__generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int>
[__unsafeForceInlineEarly]
matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
-{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
+{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
__generic<T, let addrSpace : uint64_t>
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 790dfaa79..d32584db0 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -176,7 +176,7 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
// for internal use.
//
[__AutoDiffBuiltin]
-export struct NullDifferential : IDifferentiable
+struct NullDifferential : IDifferentiable
{
// for now, we'll use at least one field to make sure the type is non-empty
uint dummy;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 173b97fd3..78eb48a33 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -20309,7 +20309,7 @@ void ReorderThread( HitObject HitOrMiss )
///
/// There doesn't appear to be an equivalent for debugBreak for HLSL
-
+[require(glsl)]
__specialized_for_target(glsl)
[[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]]
void __glslDebugBreak();
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5b5e05b73..0c42817c8 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10452,6 +10452,12 @@ void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDec
{
return;
}
+
+ if (getText(moduleDecl->getName()) == "glsl")
+ {
+ getShared()->glslModuleDecl = moduleDecl;
+ }
+
importedModulesList.add(moduleDecl);
importedModulesSet.add(moduleDecl);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 6438a91e3..edb199299 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -172,15 +172,17 @@ struct BasicTypeKeyPair
struct OperatorOverloadCacheKey
{
- intptr_t operatorName;
+ int32_t operatorName;
+ bool isGLSLMode;
BasicTypeKey args[2];
bool operator==(OperatorOverloadCacheKey key) const
{
- return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1];
+ return operatorName == key.operatorName && args[0] == key.args[0] &&
+ args[1] == key.args[1] && isGLSLMode == key.isGLSLMode;
}
HashCode getHashCode() const
{
- return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw());
+ return combineHash(operatorName, args[0].getRaw(), args[1].getRaw(), isGLSLMode ? 1 : 0);
}
bool fromOperatorExpr(OperatorExpr* opExpr)
{
@@ -299,10 +301,28 @@ struct OverloadCandidate
SubstitutionSet subst;
};
-struct TypeCheckingCache
+struct ResolvedOperatorOverload // Cached result of resolving one operator overload.
{
- Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
+ // The resolved decl. Kept so that `candidate` can be recreated when it is no longer valid.
+ Decl* decl;
+
+ // The cached overload candidate in the current TypeCheckingCache.
+ // Note that an `OverloadCandidate` object is not migratable across different
+ // Linkages (compile sessions), so we need to use `cacheVersion` to track
+ // whether this `candidate` is valid for the current session. If not, we will
+ // recreate it from `decl`.
+ OverloadCandidate candidate;
+ // The TypeCheckingCache version at which `candidate` was stored and for which it is valid.
+ int cacheVersion;
+};
+
+struct TypeCheckingCache : public RefObject // Caches of type-checking results.
+{
+ Dictionary<OperatorOverloadCacheKey, ResolvedOperatorOverload> resolvedOperatorOverloadCache; // Resolved operator overloads, by cache key.
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
+
+ // The current version; a cached ResolvedOperatorOverload candidate is reused when its `cacheVersion` equals this.
+ int version = 0;
};
enum class CoercionSite
@@ -635,6 +655,9 @@ struct SharedSemanticsContext : public RefObject
DiagnosticSink* m_sink = nullptr;
+ // The `glsl` module's declaration if that module has been imported; null otherwise.
+ ModuleDecl* glslModuleDecl = nullptr;
+
/// (optional) modules that comes from previously processed translation units in the
/// front-end request that are made visible to the module being checked. This allows
/// `import` to use them instead of trying to find the files in file system.
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index b75f95f9a..f548fb819 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -2509,27 +2509,6 @@ String SemanticsVisitor::getCallSignatureString(OverloadResolveContext& context)
Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
{
OverloadResolveContext context;
- // check if this is a core module operator call, if so we want to use cached results
- // to speed up compilation
- bool shouldAddToCache = false;
- OperatorOverloadCacheKey key;
- TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
- if (auto opExpr = as<OperatorExpr>(expr))
- {
- if (key.fromOperatorExpr(opExpr))
- {
- OverloadCandidate candidate;
- if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
- {
- context.bestCandidateStorage = candidate;
- context.bestCandidate = &context.bestCandidateStorage;
- }
- else
- {
- shouldAddToCache = true;
- }
- }
- }
// Look at the base expression for the call, and figure out how to invoke it.
auto funcExpr = expr->functionExpr;
@@ -2569,6 +2548,43 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
context.loc = expr->loc;
context.sourceScope = m_outerScope;
context.baseExpr = GetBaseExpr(funcExpr);
+
+ // check if this is a core module operator call, if so we want to use cached results
+ // to speed up compilation
+ bool shouldAddToCache = false;
+ OperatorOverloadCacheKey key;
+ TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
+ if (auto opExpr = as<OperatorExpr>(expr))
+ {
+ if (key.fromOperatorExpr(opExpr))
+ {
+ key.isGLSLMode = getShared()->glslModuleDecl != nullptr;
+ ResolvedOperatorOverload candidate;
+ if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
+ {
+ // We should only use the cached candidate if it is persistent direct declref
+ // created from GlobalSession's ASTBuilder, or it is created in the current Linkage.
+ if (candidate.cacheVersion == typeCheckingCache->version ||
+ as<DirectDeclRef>(candidate.candidate.item.declRef.declRefBase))
+ {
+ context.bestCandidateStorage = candidate.candidate;
+ context.bestCandidate = &context.bestCandidateStorage;
+ }
+ else
+ {
+ LookupResultItem overloadCandidate = {};
+ overloadCandidate.declRef = getOuterGenericOrSelf(candidate.decl);
+ AddDeclRefOverloadCandidates(overloadCandidate, context, 0);
+ shouldAddToCache = true;
+ }
+ }
+ else
+ {
+ shouldAddToCache = true;
+ }
+ }
+ }
+
// We run a special case here where an `InvokeExpr`
// with a single argument where the base/func expression names
// a type should always be treated as an explicit type coercion
@@ -2731,7 +2747,18 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
// We will report errors for this one candidate, then, to give
// the user the most help we can.
if (shouldAddToCache)
- typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate;
+ {
+ if (isFromCoreModule(context.bestCandidate->item.declRef.getDecl()) ||
+ getShared()->glslModuleDecl ==
+ getModuleDecl(context.bestCandidate->item.declRef.getDecl()))
+ {
+ ResolvedOperatorOverload overloadResult;
+ overloadResult.candidate = *context.bestCandidate;
+ overloadResult.decl = context.bestCandidate->item.declRef.getDecl();
+ overloadResult.cacheVersion = typeCheckingCache->version;
+ typeCheckingCache->resolvedOperatorOverloadCache[key] = overloadResult;
+ }
+ }
// Now that we have resolved the overload candidate, we need to undo an `openExistential`
// operation that was applied to `out` arguments.
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 38725fff3..3e00e4b04 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2209,7 +2209,7 @@ public:
TypeCheckingCache* getTypeCheckingCache();
void destroyTypeCheckingCache();
- TypeCheckingCache* m_typeCheckingCache = nullptr;
+ RefPtr<RefObject> m_typeCheckingCache = nullptr;
// Modules that have been dynamically loaded via `import`
//
@@ -3589,6 +3589,9 @@ public:
int m_typeDictionarySize = 0;
+ RefPtr<RefObject> m_typeCheckingCache;
+ TypeCheckingCache* getTypeCheckingCache();
+
private:
struct BuiltinModuleInfo
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index cf653d001..03f52a701 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -49,6 +49,7 @@ DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'")
DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'")
DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'")
DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'")
+DIAGNOSTIC(-1, Note, seeCallOfFunc, "see call to '$0'")
DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition")
DIAGNOSTIC(
-1,
@@ -2309,7 +2310,7 @@ DIAGNOSTIC(
41402,
Error,
staticAssertionConditionNotConstant,
- "condition for static assertion cannot be evaluated at the compile-time.")
+ "condition for static assertion cannot be evaluated at compile time.")
DIAGNOSTIC(
41402,
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index e20a4a90f..847c5b55c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -475,6 +475,31 @@ void calcRequiredLoweringPassSet(
}
}
+/// Emit up to `maxDepth` "see call to" notes that walk outward from `inst`.
+///
+/// At each step the function enclosing the current instruction is located and the
+/// first `IRCall` found among that function's uses is reported through `sink`; the
+/// walk then continues from that call. It stops as soon as there is no enclosing
+/// function or no call among the uses.
+void diagnoseCallStack(IRInst* inst, DiagnosticSink* sink)
+{
+    static const int maxDepth = 5;
+    int depth = 0;
+    while (depth < maxDepth)
+    {
+        auto enclosingFunc = getParentFunc(inst);
+        if (!enclosingFunc)
+            break;
+
+        // Find the first use of the function whose user is a call instruction.
+        IRCall* callSite = nullptr;
+        for (auto use = enclosingFunc->firstUse; use && !callSite; use = use->nextUse)
+            callSite = as<IRCall>(use->getUser());
+        if (!callSite)
+            break;
+
+        sink->diagnose(callSite, Diagnostics::seeCallOfFunc, enclosingFunc);
+
+        // Continue the walk from the call site that was just reported.
+        inst = callSite;
+        ++depth;
+    }
+}
+
bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
{
switch (inst->getOp())
@@ -498,6 +523,7 @@ bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
{
sink->diagnose(inst, Diagnostics::staticAssertionFailureWithoutMessage);
}
+ diagnoseCallStack(inst, sink);
}
}
else
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 6fc7d56ad..d8500a694 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -2012,6 +2012,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeArrayFromElement:
case kIROp_MakeTuple:
case kIROp_MakeValuePack:
+ case kIROp_BuiltinCast:
return transcribeConstruct(builder, origInst);
case kIROp_MakeStruct:
return transcribeMakeStruct(builder, origInst);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 9075002e0..40dcb1b51 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1026,12 +1026,9 @@ IRInst* AutoDiffSharedContext::findDifferentiableInterface()
{
for (auto globalInst : module->getGlobalInsts())
{
- // TODO: This seems like a particularly dangerous way to look for an interface.
- // See if we can lower IDifferentiable to a separate IR inst.
- //
if (auto intf = as<IRInterfaceType>(globalInst))
{
- if (auto decor = intf->findDecoration<IRNameHintDecoration>())
+ if (auto decor = intf->findDecoration<IRKnownBuiltinDecoration>())
{
if (decor->getName() == toSlice("IDifferentiable"))
{
@@ -1261,7 +1258,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
}
addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness());
-
+#if 0
// TODO: Is this really needed?
if (!as<IRInterfaceType>(item->getBaseType()) &&
!as<IRAssociatedType>(item->getBaseType()))
@@ -1314,6 +1311,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
addTypeToDictionary((IRType*)diffType, diffWitness);
}
}
+#endif
}
}
}
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 53bbbda9e..125ab71d0 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -52,15 +52,28 @@ struct IRSharedSpecContext
// A map from mangled symbol names to zero or
// more global IR values that have that name,
// in the *original* module.
- typedef Dictionary<String, RefPtr<IRSpecSymbol>> SymbolDictionary;
+ typedef Dictionary<ImmutableHashedString, RefPtr<IRSpecSymbol>> SymbolDictionary;
SymbolDictionary symbols;
+ Dictionary<ImmutableHashedString, bool> isImportedSymbol;
+
+ bool useAutodiff = false;
+
IRBuilder builderStorage;
// The "global" specialization environment.
IRSpecEnv globalEnv;
};
+void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv);
+
+struct WitnessTableCloneInfo : RefObject // Bookkeeping for one witness table clone.
+{
+ IRWitnessTable* clonedTable; // The clone of `originalTable`.
+ IRWitnessTable* originalTable; // The table that was cloned.
+ Dictionary<UnownedStringSlice, IRWitnessTableEntry*> deferredEntries; // Entries of `originalTable` skipped while cloning, keyed by the mangled name of their requirement key.
+};
+
struct IRSpecContextBase
{
IRSharedSpecContext* shared;
@@ -69,7 +82,27 @@ struct IRSpecContextBase
IRModule* getModule() { return getShared()->module; }
- IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; }
+ List<IRModule*> irModules;
+
+ HashSet<UnownedStringSlice> deferredWitnessTableEntryKeys;
+ List<RefPtr<WitnessTableCloneInfo>> witnessTables;
+
+    /// Look up the symbol list registered under `mangledName`.
+    ///
+    /// A hit in `shared->symbols` (including a previously stored null entry) is
+    /// returned directly. On a miss, every module in `irModules` is asked for
+    /// instructions with that mangled name, each one is registered through
+    /// `insertGlobalValueSymbol`, and the lookup is retried. If the name is still
+    /// unknown, a null entry is stored for it and null is returned.
+    IRSpecSymbol* findSymbols(UnownedStringSlice mangledName)
+    {
+        ImmutableHashedString key(mangledName);
+        RefPtr<IRSpecSymbol> found;
+        if (!shared->symbols.tryGetValue(key, found))
+        {
+            // Not registered yet: pull matching globals in from the input modules.
+            for (auto module : irModules)
+            {
+                for (auto inst : module->findSymbolByMangledName(key))
+                    insertGlobalValueSymbol(shared, inst);
+            }
+            if (!shared->symbols.tryGetValue(key, found))
+            {
+                // Record the miss so later queries for this name hit the dictionary.
+                shared->symbols[key] = nullptr;
+                return nullptr;
+            }
+        }
+        return found;
+    }
// The current specialization environment to use.
IRSpecEnv* env = nullptr;
@@ -105,6 +138,26 @@ void registerClonedValue(IRSpecContextBase* context, IRInst* clonedValue, IRInst
// an `Add()` call.
//
context->getEnv()->clonedValues[originalValue] = clonedValue;
+
+ switch (clonedValue->getOp())
+ {
+ case kIROp_LookupWitness:
+
+ // `clonedValue` is a witness lookup: record the mangled name of its
+ // requirement key in `deferredWitnessTableEntryKeys`.
+ context->deferredWitnessTableEntryKeys.add(
+ getMangledName(as<IRLookupWitnessMethod>(clonedValue)->getRequirementKey()));
+ break;
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ if (context->getShared()->useAutodiff)
+ {
+ if (auto key = as<IRStructKey>(clonedValue->getOperand(0)))
+ context->deferredWitnessTableEntryKeys.add(getMangledName(key));
+ }
+ break;
+ }
}
// Information on values to use when registering a cloned value
@@ -149,6 +202,28 @@ IRInst* cloneInst(IRSpecContextBase* context, IRBuilder* builder, IRInst* origin
return cloneInst(context, builder, originalInst, originalInst);
}
+/// Returns true when the opcode of `decor` is one of the derivative,
+/// primal-substitute, differentiable-type-dictionary or differentiable-marker
+/// decorations listed below; false for every other opcode.
+bool isAutoDiffDecoration(IRInst* decor)
+{
+    bool isAutoDiff = false;
+    switch (decor->getOp())
+    {
+    // Forward-mode derivative and differentiability markers.
+    case kIROp_ForwardDerivativeDecoration:
+    case kIROp_ForwardDifferentiableDecoration:
+    case kIROp_BackwardDifferentiableDecoration:
+    // Backward-mode derivative decorations.
+    case kIROp_BackwardDerivativeDecoration:
+    case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+    case kIROp_BackwardDerivativePrimalDecoration:
+    case kIROp_BackwardDerivativePropagateDecoration:
+    case kIROp_BackwardDerivativePrimalContextDecoration:
+    case kIROp_BackwardDerivativePrimalReturnDecoration:
+    case kIROp_UserDefinedBackwardDerivativeDecoration:
+    // Remaining decorations in the set.
+    case kIROp_PrimalSubstituteDecoration:
+    case kIROp_DifferentiableTypeDictionaryDecoration:
+        isAutoDiff = true;
+        break;
+    default:
+        break;
+    }
+    return isAutoDiff;
+}
+
/// Clone any decorations from `originalValue` onto `clonedValue`
void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* originalValue)
{
@@ -165,6 +240,8 @@ void cloneDecorations(IRSpecContextBase* context, IRInst* clonedValue, IRInst* o
SLANG_UNUSED(context);
for (auto originalDecoration : originalValue->getDecorations())
{
+ if (!context->shared->useAutodiff && isAutoDiffDecoration(originalDecoration))
+ continue;
cloneInst(context, builder, originalDecoration);
}
@@ -185,6 +262,8 @@ void cloneDecorationsAndChildren(
SLANG_UNUSED(context);
for (auto originalItem : originalValue->getDecorationsAndChildren())
{
+ if (!context->shared->useAutodiff && isAutoDiffDecoration(originalItem))
+ continue;
cloneInst(context, builder, originalItem);
}
@@ -294,7 +373,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
cloneDecorationsAndChildren(this, clonedValue, originalValue);
addHoistableInst(builder, clonedValue);
-
return clonedValue;
}
break;
@@ -407,15 +485,17 @@ static void cloneExtraDecorationsFromInst(
{
default:
break;
-
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_PrimalSubstituteDecoration:
+ if (!context->getShared()->useAutodiff)
+ break;
+ [[fallthrough]];
case kIROp_HLSLExportDecoration:
case kIROp_BindExistentialSlotsDecoration:
case kIROp_LayoutDecoration:
case kIROp_PublicDecoration:
case kIROp_SequentialIDDecoration:
- case kIROp_ForwardDerivativeDecoration:
- case kIROp_UserDefinedBackwardDerivativeDecoration:
- case kIROp_PrimalSubstituteDecoration:
case kIROp_IntrinsicOpDecoration:
case kIROp_NonCopyableTypeDecoration:
case kIROp_DynamicDispatchWitnessDecoration:
@@ -602,6 +682,40 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl(
return clonedVal;
}
+/// Decide whether every entry of `table` is cloned up front.
+///
+/// Returns true when the table carries an HLSL-export or keep-alive decoration, or
+/// when its resolved conformance type carries a COM-interface decoration. When the
+/// conformance type is the known builtin `IDifferentiable` or `IDifferentiablePtr`,
+/// the answer is the shared context's `useAutodiff` flag. Otherwise returns false.
+bool shouldDeepCloneWitnessTable(IRSpecContextBase* context, IRWitnessTable* table)
+{
+    // Decorations on the table itself.
+    for (auto decor : table->getDecorations())
+    {
+        const auto op = decor->getOp();
+        if (op == kIROp_HLSLExportDecoration || op == kIROp_KeepAliveDecoration)
+            return true;
+    }
+
+    // Decorations on the interface the table conforms to.
+    auto resolvedInterface = getResolvedInstForDecorations(table->getConformanceType());
+    for (auto decor : resolvedInterface->getDecorations())
+    {
+        if (decor->getOp() == kIROp_ComInterfaceDecoration)
+            return true;
+        if (decor->getOp() != kIROp_KnownBuiltinDecoration)
+            continue;
+
+        auto builtinName = as<IRKnownBuiltinDecoration>(decor)->getName();
+        const bool isDifferentiableInterface =
+            builtinName == toSlice("IDifferentiable") ||
+            builtinName == toSlice("IDifferentiablePtr");
+        if (isDifferentiableInterface)
+            return context->getShared()->useAutodiff;
+    }
+
+    return false;
+}
IRWitnessTable* cloneWitnessTableImpl(
IRSpecContextBase* context,
@@ -612,13 +726,56 @@ IRWitnessTable* cloneWitnessTableImpl(
bool registerValue = true)
{
IRWitnessTable* clonedTable = dstTable;
+ IRType* clonedBaseType = nullptr;
if (!clonedTable)
{
- auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
+ clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
}
- cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
+ else
+ {
+ clonedBaseType = (IRType*)clonedTable->getConformanceType();
+ }
+ if (registerValue)
+ registerClonedValue(context, clonedTable, originalValues);
+
+ // Set up an IR builder for inserting into the witness table
+ IRBuilder builderStorage = *context->builder;
+ IRBuilder* entryBuilder = &builderStorage;
+ entryBuilder->setInsertInto(clonedTable);
+
+ // Clone decorations first
+ for (auto decoration : originalTable->getDecorations())
+ {
+ cloneInst(context, entryBuilder, decoration);
+ }
+ cloneExtraDecorations(context, clonedTable, originalValues);
+
+ RefPtr<WitnessTableCloneInfo> witnessInfo = new WitnessTableCloneInfo();
+ witnessInfo->clonedTable = clonedTable;
+ witnessInfo->originalTable = originalTable;
+
+ bool shouldDeepClone = shouldDeepCloneWitnessTable(context, originalTable);
+
+ // Clone only the witness table entries that are actually used
+ for (auto child : originalTable->getDecorationsAndChildren())
+ {
+ if (auto entry = as<IRWitnessTableEntry>(child))
+ {
+ if (!shouldDeepClone)
+ {
+ // Skip witness table entries during the first pass,
+ // and just add them to the deferred work list.
+ witnessInfo->deferredEntries.add(getMangledName(entry->getRequirementKey()), entry);
+ continue;
+ }
+ }
+ // Clone any non-entry children as is
+ cloneInst(context, entryBuilder, child);
+ }
+ context->witnessTables.add(witnessInfo);
+
return clonedTable;
}
@@ -740,6 +897,9 @@ void cloneGlobalValueWithCodeCommon(
}
else
{
+ if (oi->getOp() == kIROp_DifferentiableTypeAnnotation &&
+ !context->getShared()->useAutodiff)
+ continue;
cloneInst(context, builder, oi);
}
}
@@ -886,12 +1046,12 @@ IRFunc* specializeIRForEntryPoint(
// so that the mangled name of the decl-ref is
// not the same as the mangled name of the decl.
//
- RefPtr<IRSpecSymbol> sym;
- if (!context->getSymbols().tryGetValue(mangledName, sym))
+ IRSpecSymbol* sym = context->findSymbols(mangledName.getUnownedSlice());
+ if (!sym)
{
String hashedName = getHashedName(mangledName.getUnownedSlice());
-
- if (!context->getSymbols().tryGetValue(hashedName, sym))
+ sym = context->findSymbols(hashedName.getUnownedSlice());
+ if (!sym)
{
SLANG_UNEXPECTED("no matching IR symbol");
return nullptr;
@@ -967,7 +1127,7 @@ IRFunc* specializeIRForEntryPoint(
if (!clonedFunc)
{
SLANG_UNEXPECTED("expected entry point to be a function");
- return nullptr;
+ UNREACHABLE_RETURN(nullptr);
}
if (!clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration))
@@ -1022,7 +1182,8 @@ CapabilitySet getTargetCapabilities(IRSpecContext* context)
return context->getShared()->targetReq->getTargetCaps();
}
-/// Get the most appropriate ("best") capability requirements for `inVal` based on the `targetCaps`.
+/// Get the most appropriate ("best") capability requirements for `inVal` based on the
+/// `targetCaps`.
static CapabilitySet _getBestSpecializationCaps(IRInst* inVal, CapabilitySet const& targetCaps)
{
IRInst* val = getResolvedInstForDecorations(inVal);
@@ -1335,9 +1496,9 @@ IRInst* cloneGlobalValueWithLinkage(
// with the same mangled name as `originalVal` and try
// to pick the "best" one for our target.
- auto mangledName = String(originalLinkage->getMangledName());
- RefPtr<IRSpecSymbol> sym;
- if (!context->getSymbols().tryGetValue(mangledName, sym))
+ auto mangledName = originalLinkage->getMangledName();
+ IRSpecSymbol* sym = context->findSymbols(mangledName);
+ if (!sym)
{
if (!originalVal)
return nullptr;
@@ -1419,6 +1580,11 @@ void insertGlobalValueSymbol(IRSharedSpecContext* sharedContext, IRInst* gv)
{
sharedContext->symbols.add(mangledName, sym);
}
+
+ if (as<IRImportDecoration>(linkage))
+ sharedContext->isImportedSymbol.tryGetValueOrAdd(mangledName, true);
+ else
+ sharedContext->isImportedSymbol.set(mangledName, false);
}
void insertGlobalValueSymbols(IRSharedSpecContext* sharedContext, IRModule* originalModule)
@@ -1605,8 +1771,8 @@ void convertAtomicToStorageBuffer(
IRSpecContext* context,
Dictionary<int, List<IRInst*>>& bindingToInstMapUnsorted)
{
- // Atomic_uint definitions needs to become a storage buffer to follow GL_EXT_vulkan_glsl_relaxed
- // and to allow translation of atomic_uint into SPIRV
+ // atomic_uint definitions need to become storage buffers to follow
+ // GL_EXT_vulkan_glsl_relaxed and to allow translation of atomic_uint into SPIR-V
IRBuilder builder = *context->builder;
@@ -1653,8 +1819,8 @@ void convertAtomicToStorageBuffer(
instToSwitch->setFullType(storageBuffer);
// All references to a atomic_uint need to be an element ref. to emulate storage buffer
- // usage All function calls must be inlined since storage buffers cannot pass as parameters
- // to atomic methods
+ // usage. All function calls must be inlined since storage buffers cannot be passed as
+ // parameters to atomic methods
for (auto& i : bindingToInstList.second)
{
int64_t currOffset =
@@ -1722,12 +1888,13 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
{
case kIROp_GLSLAtomicUintType:
{
- // atomic_uint are supported by GLSL->VK through converting to a different type
- // (GL_EXT_vulkan_glsl_relaxed). atomic_uint are not supported by SPIR-V->VK;
- // this means that to get SPIR-V to work we must convert the type ourselves to
- // an equivlent representation (storage buffer); the added benifit is that then
- // HLSL is possible to emit as a target as well since atomic_uint is not an HLSL
- // concept, but storageBuffer->RWBuffer is and HLSL concept
+ // atomic_uint is supported by GLSL->VK through conversion to a different
+ // type (GL_EXT_vulkan_glsl_relaxed). atomic_uint is not supported by
+ // SPIR-V->VK; this means that to get SPIR-V to work we must convert the
+ // type ourselves to an equivalent representation (storage buffer); the added
+ // benefit is that HLSL then becomes possible to emit as a target as well, since
+ // atomic_uint is not an HLSL concept, but storageBuffer->RWBuffer is an
+ // HLSL concept
auto layout = inst->findDecoration<IRLayoutDecoration>()->getLayout();
auto layoutVal = as<IRVarOffsetAttr>(layout->getOperand(1));
assert(layoutVal != nullptr);
@@ -1743,6 +1910,110 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted);
}
+bool isDiffPairType(IRInst* type)
+{
+    // Run `type` through unwrapAttributedType and unwrapArray repeatedly
+    // until a pass leaves it unchanged, then test whether what remains is a
+    // differential pair type.
+    IRInst* previous = nullptr;
+    do
+    {
+        previous = type;
+        type = unwrapArray(static_cast<IRType*>(unwrapAttributedType(previous)));
+    } while (type != previous);
+    return as<IRDifferentialPairTypeBase>(type) != nullptr;
+}
+
+// Returns true if `inst`, or any non-imported instruction nested under it,
+// uses auto-diff. Despite the name, this accepts any instruction; linkIR calls
+// it on a module's root instruction and it recurses through the children.
+bool doesModuleUseAutodiff(IRInst* inst)
+{
+    const auto op = inst->getOp();
+
+    // Instructions that operate on differential pairs directly.
+    if (op == kIROp_DifferentialPairGetDifferentialUserCode ||
+        op == kIROp_DifferentialPairGetPrimalUserCode ||
+        op == kIROp_DifferentialPtrPairGetPrimal ||
+        op == kIROp_DifferentialPtrPairGetDifferential)
+    {
+        return true;
+    }
+
+    // Fields and parameters count when their type is a differential pair.
+    if (op == kIROp_StructField)
+        return isDiffPairType(as<IRStructField>(inst)->getFieldType());
+    if (op == kIROp_Param)
+        return isDiffPairType(inst->getDataType());
+
+    // A call counts when its resolved callee is a differentiate operation.
+    if (op == kIROp_Call)
+    {
+        auto callee = getResolvedInstForDecorations(inst->getOperand(0));
+        if (!callee)
+            return false;
+        switch (callee->getOp())
+        {
+        case kIROp_ForwardDifferentiate:
+        case kIROp_BackwardDifferentiate:
+        case kIROp_BackwardDifferentiatePrimal:
+        case kIROp_BackwardDifferentiatePropagate:
+            return true;
+        default:
+            return false;
+        }
+    }
+
+    // Anything else: inspect each child's decorations, then recurse into it.
+    for (auto child : inst->getChildren())
+    {
+        // First pass, in decoration order: an auto-py-bind decoration seen
+        // before any import decoration answers "yes" immediately; an import
+        // decoration ends the scan and causes the child to be skipped.
+        bool imported = false;
+        for (auto decoration : child->getDecorations())
+        {
+            if (as<IRAutoPyBindCudaDecoration>(decoration) ||
+                as<IRAutoPyBindExportInfoDecoration>(decoration))
+            {
+                return true;
+            }
+            if (as<IRImportDecoration>(decoration))
+            {
+                imported = true;
+                break;
+            }
+        }
+        if (imported)
+            continue;
+
+        // Second pass: any auto-diff decoration on a non-imported child.
+        for (auto decoration : child->getDecorations())
+        {
+            if (isAutoDiffDecoration(decoration))
+                return true;
+        }
+
+        if (doesModuleUseAutodiff(child))
+            return true;
+    }
+    return false;
+}
+
+// Clones the deferred witness table entries whose keys have been recorded in
+// context->deferredWitnessTableEntryKeys. Cloning an entry can change the
+// number of recorded keys, so the tables are re-scanned until a full pass
+// leaves the key count untouched.
+void cloneUsedWitnessTableEntries(IRSpecContext* context)
+{
+    for (bool keyCountChanged = true; keyCountChanged;)
+    {
+        keyCountChanged = false;
+        for (Index tableIndex = 0; tableIndex < context->witnessTables.getCount(); tableIndex++)
+        {
+            auto tableInfo = context->witnessTables[tableIndex].get();
+
+            // Keys are collected here and removed after the loop, so that
+            // deferredEntries is not modified while it is being iterated.
+            ShortList<UnownedStringSlice> clonedKeys;
+            for (auto deferred : tableInfo->deferredEntries)
+            {
+                if (!context->deferredWitnessTableEntryKeys.contains(deferred.first))
+                    continue;
+
+                IRBuilder builder(tableInfo->clonedTable);
+                builder.setInsertInto(tableInfo->clonedTable);
+                const auto keyCountBefore = context->deferredWitnessTableEntryKeys.getCount();
+                cloneInst(context, &builder, deferred.second);
+                clonedKeys.add(deferred.first);
+                if (context->deferredWitnessTableEntryKeys.getCount() != keyCountBefore)
+                    keyCountChanged = true;
+            }
+            for (auto key : clonedKeys)
+                tableInfo->deferredEntries.remove(key);
+        }
+    }
+}
+
LinkedIR linkIR(CodeGenContext* codeGenContext)
{
SLANG_PROFILE;
@@ -1763,50 +2034,61 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
state->target = target;
state->targetReq = targetReq;
-
+ auto& irModules = stateStorage.contextStorage.irModules;
auto sharedContext = state->getSharedContext();
initializeSharedSpecContext(sharedContext, session, nullptr, target, targetReq);
state->irModule = sharedContext->module;
-
// We need to be able to look up IR definitions for any symbols in
// modules that the program depends on (transitively). To
// accelerate lookup, we will create a symbol table for looking
// up IR definitions by their mangled name.
//
-
- List<IRModule*> irModules;
-
- // Link the core modules.
- auto& coreModules = static_cast<Session*>(linkage->getGlobalSession())->coreModules;
- for (auto& m : coreModules)
- irModules.add(m->getIRModule());
+ auto globalSession = static_cast<Session*>(linkage->getGlobalSession());
+ List<IRModule*> builtinModules;
+ for (auto& m : globalSession->coreModules)
+ builtinModules.add(m->getIRModule());
// Link modules in the program.
- program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); });
-
- // Add any modules that were loaded as libraries
- for (IRModule* irModule : irModules)
- {
- insertGlobalValueSymbols(sharedContext, irModule);
- }
+ program->enumerateIRModules(
+ [&](IRModule* module)
+ {
+ if (module->getName() == globalSession->glslModuleName)
+ builtinModules.add(module);
+ else
+ irModules.add(module);
+ });
- // We will also insert the IR global symbols from the IR module
+ // We will also consider the IR global symbols from the IR module
// attached to the `TargetProgram`, since this module is
// responsible for associating layout information to those
// global symbols via decorations.
//
auto irModuleForLayout = targetProgram->getExistingIRModuleForLayout();
- insertGlobalValueSymbols(sharedContext, irModuleForLayout);
+ if (irModuleForLayout)
+ irModules.add(irModuleForLayout);
+
+ Index userModuleCount = irModules.getCount();
+ irModules.addRange(builtinModules);
+ ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount);
+
+ // Check if any user module uses auto-diff, if so we will need to link
+ // additional witnesses and decorations.
+ for (IRModule* irModule : userModules)
+ {
+ if (sharedContext->useAutodiff)
+ break;
+ sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst());
+ }
auto context = state->getContext();
// Combine all of the contents of IRGlobalHashedStringLiterals
{
StringSlicePool pool(StringSlicePool::Style::Empty);
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
findGlobalHashedStringLiterals(irModule, pool);
}
@@ -1875,7 +2157,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// instructions in all the input modules.
//
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
for (auto inst : irModule->getGlobalInsts())
{
@@ -1896,7 +2178,9 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// We need to copy over exported symbols,
// and any global parameters if preserve-params option is set.
if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) ||
- as<IRDifferentiableTypeAnnotation>(inst))
+ sharedContext->useAutodiff &&
+ (as<IRDifferentiableTypeAnnotation>(inst) ||
+ inst->findDecorationImpl(kIROp_AutoDiffBuiltinDecoration) != nullptr))
{
auto cloned = cloneValue(context, inst);
if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration))
@@ -1907,6 +2191,12 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
}
}
+ // In previous steps, we have skipped cloning the witness table entries, and
+ // registered any used witness table entry keys to context->deferredWitnessTableEntryKeys
+ // for on-demand cloning. Now we will use the deferred keys to clone the witness table
+ // entries that are referenced.
+ cloneUsedWitnessTableEntries(context);
+
// It is possible that metadata has been attached to the input modules
// themselves, which should be copied over to the output module.
//
@@ -1917,7 +2207,7 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
// `[assumedWaveSize(...)]` decoration might require that all specified
// values match exactly).
//
- for (IRModule* irModule : irModules)
+ for (IRModule* irModule : userModules)
{
for (auto decoration : irModule->getModuleInst()->getDecorations())
{
diff --git a/source/slang/slang-ir-strip.cpp b/source/slang/slang-ir-strip.cpp
index e9fbbceb6..6c64bfca2 100644
--- a/source/slang/slang-ir-strip.cpp
+++ b/source/slang/slang-ir-strip.cpp
@@ -51,4 +51,34 @@ void stripFrontEndOnlyInstructions(IRModule* module, IRStripOptions const& optio
_stripFrontEndOnlyInstructionsRec(module->getModuleInst(), options);
}
+// Removes the entries of every witness table in `module` that carries an
+// import decoration. A witness table may appear directly at global scope or
+// as the innermost return value of a global generic; in the generic case the
+// import decoration is looked up on the generic itself.
+void stripImportedWitnessTable(IRModule* module)
+{
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        IRInst* table = nullptr;
+        switch (globalInst->getOp())
+        {
+        case kIROp_Generic:
+            table = findInnerMostGenericReturnVal(as<IRGeneric>(globalInst));
+            break;
+        case kIROp_WitnessTable:
+            table = globalInst;
+            break;
+        default:
+            continue;
+        }
+        // A generic may yield no return value, or one that is not a witness table.
+        if (!table || table->getOp() != kIROp_WitnessTable)
+            continue;
+        if (!globalInst->findDecoration<IRImportDecoration>())
+            continue;
+
+        // The children to strip are witness table *entries*. Testing for
+        // kIROp_WitnessTable here would never match and leave every entry in
+        // place. Fetch the successor before removing, since removal unlinks
+        // the child from its sibling list.
+        for (auto child = table->getFirstChild(); child;)
+        {
+            auto nextChild = child->getNextInst();
+            if (child->getOp() == kIROp_WitnessTableEntry)
+                child->removeAndDeallocate();
+            child = nextChild;
+        }
+    }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-strip.h b/source/slang/slang-ir-strip.h
index 9d97669a1..f6f53aeda 100644
--- a/source/slang/slang-ir-strip.h
+++ b/source/slang/slang-ir-strip.h
@@ -13,5 +13,9 @@ struct IRStripOptions
/// Strip out instructions that should only be used by the front-end.
void stripFrontEndOnlyInstructions(IRModule* module, IRStripOptions const& options);
+
+/// Strip witness table entries from imported witness tables.
+void stripImportedWitnessTable(IRModule* module);
+
} // namespace Slang
#pragma once
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index a3cf28a68..dbd6ac099 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2175,4 +2175,14 @@ void legalizeDefUse(IRGlobalValueWithCode* func)
}
}
+// Returns the mangled name stored on the first linkage decoration of `inst`,
+// or an empty slice when `inst` has no linkage decoration.
+UnownedStringSlice getMangledName(IRInst* inst)
+{
+    auto linkageDecor = inst->findDecoration<IRLinkageDecoration>();
+    if (!linkageDecor)
+        return UnownedStringSlice();
+    return linkageDecor->getMangledName();
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 46e0105f5..610524754 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -377,6 +377,8 @@ Int getSpecializationConstantId(IRGlobalParam* param);
void legalizeDefUse(IRGlobalValueWithCode* func);
+UnownedStringSlice getMangledName(IRInst* inst);
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f28f61ffc..fb274c4a0 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4538,6 +4538,18 @@ RefPtr<IRModule> IRModule::create(Session* session)
return module;
}
+// Rebuilds the mangled-name lookup table from scratch: every global
+// instruction that has a linkage decoration is appended to the list stored
+// under its mangled name. Globals without a linkage decoration are skipped.
+void IRModule::buildMangledNameToGlobalInstMap()
+{
+    m_mapMangledNameToGlobalInst.clear();
+    for (auto globalInst : getGlobalInsts())
+    {
+        auto linkage = globalInst->findDecoration<IRLinkageDecoration>();
+        if (!linkage)
+            continue;
+        m_mapMangledNameToGlobalInst[linkage->getMangledName()].add(globalInst);
+    }
+}
+
IRDominatorTree* IRModule::findOrCreateDominatorTree(IRGlobalValueWithCode* func)
{
IRAnalysis* analysis = m_mapInstToAnalysis.tryGetValue(func);
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 8f53c9f14..ecf5d1c66 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -2374,6 +2374,15 @@ public:
m_obfuscatedSourceMap = sourceMap;
}
+    /// Look up the global instructions recorded under `mangledName`.
+    ///
+    /// Returns an empty view when the name is not present in the map. Within
+    /// this patch the map is filled only by `buildMangledNameToGlobalInstMap()`,
+    /// so the result reflects the module's globals as of the last call to it.
+    ArrayView<IRInst*> findSymbolByMangledName(const ImmutableHashedString& mangledName) const
+    {
+        if (auto list = m_mapMangledNameToGlobalInst.tryGetValue(mangledName))
+            return list->getArrayView();
+        return {};
+    }
+
+ void buildMangledNameToGlobalInstMap();
+
IRDeduplicationContext* getDeduplicationContext() const { return &m_deduplicationContext; }
IRDominatorTree* findDominatorTree(IRGlobalValueWithCode* func)
@@ -2392,6 +2401,9 @@ public:
IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); }
+    /// The name assigned to this module via `setName`; `m_name` is initialized
+    /// to nullptr, so this returns nullptr if no name was ever set.
+    Name* getName() const { return m_name; }
+    /// Record `name` as this module's name (the pointer is stored as-is).
+    void setName(Name* name) { m_name = name; }
+
/// Create an empty instruction with the `op` opcode and space for
/// a number of operands given by `operandCount`.
///
@@ -2444,6 +2456,9 @@ private:
///
IRModuleInst* m_moduleInst = nullptr;
+ // The name of the module.
+ Name* m_name = nullptr;
+
/// The memory arena from which all IR instructions (and any associated state) in this module
/// are allocated.
MemoryArena m_memoryArena;
@@ -2459,6 +2474,8 @@ private:
ComPtr<IBoxValue<SourceMap>> m_obfuscatedSourceMap;
Dictionary<IRInst*, IRAnalysis> m_mapInstToAnalysis;
+
+ Dictionary<ImmutableHashedString, List<IRInst*>> m_mapMangledNameToGlobalInst;
};
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index fbe6d8a84..e5037bf04 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8168,24 +8168,34 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// If the witness table is for a COM interface, always keep it alive.
if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
{
- subBuilder->addPublicDecoration(irWitnessTable);
+ subBuilder->addHLSLExportDecoration(irWitnessTable);
}
- if (parentDecl->findModifier<HLSLExportModifier>())
+ for (auto mod : parentDecl->modifiers)
{
- subBuilder->addHLSLExportDecoration(irWitnessTable);
- subBuilder->addKeepAliveDecoration(irWitnessTable);
+ if (as<HLSLExportModifier>(mod))
+ {
+ subBuilder->addHLSLExportDecoration(irWitnessTable);
+ subBuilder->addKeepAliveDecoration(irWitnessTable);
+ }
+ else if (as<AutoDiffBuiltinAttribute>(mod))
+ {
+ subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
+ }
}
// Make sure that all the entries in the witness table have been filled in,
// including any cases where there are sub-witness-tables for conformances
- Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
- lowerWitnessTable(
- subContext,
- inheritanceDecl->witnessTable,
- irWitnessTable,
- mapASTToIRWitnessTable);
-
+ bool isExplicitExtern = false;
+ if (!isImportedDecl(context, parentDecl, isExplicitExtern))
+ {
+ Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
+ lowerWitnessTable(
+ subContext,
+ inheritanceDecl->witnessTable,
+ irWitnessTable,
+ mapASTToIRWitnessTable);
+ }
irWitnessTable->moveToEnd();
return LoweredValInfo::simple(
@@ -11536,6 +11546,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
RefPtr<IRModule> module = IRModule::create(session);
+ module->setName(translationUnit->getModuleDecl()->getName());
+
IRBuilder builderStorage(module);
IRBuilder* builder = &builderStorage;
@@ -11804,6 +11816,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
stripOptions.stripSourceLocs = false;
stripFrontEndOnlyInstructions(module, stripOptions);
+ stripImportedWitnessTable(module);
+
// Stripping out decorations could leave some dead code behind
// in the module, and in some cases that extra code is also
// undesirable (e.g., the string literals referenced by name-hint
@@ -11847,6 +11861,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
&writer);
}
+ module->buildMangledNameToGlobalInstMap();
+
return module;
}
@@ -11884,7 +11900,7 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor
context->irBuilder = builder;
componentType->acceptVisitor(this, nullptr);
-
+ module->buildMangledNameToGlobalInstMap();
return module;
}
@@ -12040,6 +12056,7 @@ struct TypeConformanceIRGenContext
{
builder->addSequentialIDDecoration(witness, conformanceIdOverride);
}
+ module->buildMangledNameToGlobalInstMap();
return module;
}
};
@@ -12507,7 +12524,7 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink)
// Eliminate any dead code
eliminateDeadCode(irModule, options);
}
-
+ irModule->buildMangledNameToGlobalInstMap();
m_irModuleForLayout = irModule;
return irModule;
}
diff --git a/source/slang/slang-serialize-ir.cpp b/source/slang/slang-serialize-ir.cpp
index c6926735a..b8a1183b3 100644
--- a/source/slang/slang-serialize-ir.cpp
+++ b/source/slang/slang-serialize-ir.cpp
@@ -1047,6 +1047,8 @@ Result IRSerialReader::read(
}
}
+ outModule->buildMangledNameToGlobalInstMap();
+
return SLANG_OK;
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index c316974f1..ec90ee418 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -410,6 +410,11 @@ const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name)
return result;
}
+// Returns the session-level type checking cache, downcast from the handle it
+// is stored through. This accessor does not create a cache.
+TypeCheckingCache* Session::getTypeCheckingCache()
+{
+    auto storedCache = m_typeCheckingCache.get();
+    return static_cast<TypeCheckingCache*>(storedCache);
+}
+
Session::BuiltinModuleInfo Session::getBuiltinModuleInfo(slang::BuiltinModuleName name)
{
Session::BuiltinModuleInfo result;
@@ -700,6 +705,7 @@ SlangResult Session::_readBuiltinModule(
module->setModuleDecl(moduleDecl);
}
+ srcModule.irModule->setName(module->getNameObj());
module->setIRModule(srcModule.irModule);
// Put in the loaded module map
@@ -803,6 +809,10 @@ Session::createSession(slang::SessionDesc const& inDesc, slang::ISession** outSe
RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage());
+ if (m_typeCheckingCache)
+ linkage->m_typeCheckingCache =
+ new TypeCheckingCache(*static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()));
+
linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode);
Int searchPathCount = desc.searchPathCount;
@@ -1263,9 +1273,6 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka
, m_astBuilder(astBuilder)
, m_cmdLineContext(new CommandLineContext())
{
- if (builtinLinkage)
- m_astBuilder->m_cachedNodes = builtinLinkage->getASTBuilder()->m_cachedNodes;
-
getNamePool()->setRootNamePool(session->getRootNamePool());
m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr);
@@ -1297,6 +1304,17 @@ ISlangUnknown* Linkage::getInterface(const Guid& guid)
Linkage::~Linkage()
{
+    // Hand this linkage's type checking cache up to the global session, so
+    // that linkages created later can start from it (createSession copies the
+    // session's cache into each new linkage). The session's cache is replaced
+    // only when it has none yet, or when ours holds more resolved operator
+    // overload entries than the one it currently has.
+    if (m_typeCheckingCache)
+    {
+        auto globalSession = getSessionImpl();
+        if (!globalSession->m_typeCheckingCache ||
+            globalSession->getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount() <
+                getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount())
+        {
+            globalSession->m_typeCheckingCache = m_typeCheckingCache;
+        }
+    }
destroyTypeCheckingCache();
}
@@ -1318,12 +1336,11 @@ TypeCheckingCache* Linkage::getTypeCheckingCache()
{
m_typeCheckingCache = new TypeCheckingCache();
}
- return m_typeCheckingCache;
+ return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get());
}
void Linkage::destroyTypeCheckingCache()
{
- delete m_typeCheckingCache;
m_typeCheckingCache = nullptr;
}
@@ -4080,6 +4097,8 @@ RefPtr<Module> Linkage::loadModuleFromIRBlobImpl(
loadedModulesList.add(resultModule);
resultModule->setPathInfo(filePathInfo);
+ resultModule->getIRModule()->setName(resultModule->getNameObj());
+
return resultModule;
}
diff --git a/tests/language-feature/static_assert.slang b/tests/language-feature/static_assert.slang
index 55bfa0abb..d7806cf30 100644
--- a/tests/language-feature/static_assert.slang
+++ b/tests/language-feature/static_assert.slang
@@ -1,6 +1,4 @@
//TEST:SIMPLE(filecheck=CHK):-target hlsl -stage compute -entry computeMain
-//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain
-//TEST:SIMPLE(filecheck=CHK):-target spirv -stage compute -entry computeMain
//TEST:SIMPLE(filecheck=HLSL):-target hlsl -stage compute -entry computeMain
//TEST:SIMPLE(filecheck=GLSL):-target glsl -stage compute -entry computeMain
//TEST:SIMPLE(filecheck=SPV):-target spirv -stage compute -entry computeMain
@@ -56,14 +54,12 @@ extension MyType<T>
[numthreads(1,1,1)]
void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
{
- //CHK-NOT:error {{.*}} TEST_specialize
- //CHK: error {{.*}} TEST_specialize T_is_int
- //CHK-NOT:error {{.*}} TEST_specialize
+ // CHK-DAG: error {{.*}} TEST_specialize T_is_int
+ // CHK-DAG:{{.*}} TEST_specialize<float>
TEST_specialize<float>();
- //CHK-NOT:error {{.*}} TEST_specialize
- //CHK: error 41400: {{.*}} TEST_specialize T_is_float
- //CHK-NOT:error {{.*}} TEST_specialize
+ // CHK-DAG: error {{.*}} TEST_specialize T_is_float
+ // CHK-DAG:{{.*}} TEST_specialize<int>
TEST_specialize<int>();
//HLSL: error {{.*}} TEST_target_switch all
diff --git a/tools/slang-unit-test/unit-test-compile-benchmark.cpp b/tools/slang-unit-test/unit-test-compile-benchmark.cpp
new file mode 100644
index 000000000..e38edc6ad
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-compile-benchmark.cpp
@@ -0,0 +1,113 @@
+// unit-test-compile-benchmark.cpp
+
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+#include "../../tools/platform/performance-counter.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+#include "unit-test/slang-unit-test.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+using namespace Slang;
+
+// Benchmark: repeatedly create a session, load a small compute shader from a
+// source string, link it, and generate SPIR-V for its entry point. The total
+// elapsed time over all passes is reported to the test reporter.
+
+SLANG_UNIT_TEST(compileBenchmark)
+{
+    const char* userSourceBody = R"(
+// shader.slang
+
+struct PushConstantCompute
+{
+    uint64_t bufferAddress;
+    uint numVertices;
+};
+
+struct Vertex
+{
+    float3 position;
+};
+
+
+[[vk::push_constant]]
+ConstantBuffer<PushConstantCompute> pushConst;
+
+[shader("compute")]
+[numthreads(256, 1, 1)]
+void main(uint3 threadIdx : SV_DispatchThreadID)
+{
+    uint index = threadIdx.x;
+
+    if(index >= pushConst.numVertices)
+        return;
+
+    Vertex* vertices = (Vertex*)pushConst.bufferAddress;
+
+    float angle = (index + 1) * 2.3f;
+
+    float3 vertex = vertices[index].position;
+
+    float cosAngle = cos(angle);
+    float sinAngle = sin(angle);
+    float3x3 rotationMatrix = float3x3(
+ cosAngle, -sinAngle, 0.0,
+ sinAngle, cosAngle, 0.0,
+ 0.0, 0.0, 1.0
+    );
+
+    float3 rotatedVertex = mul(rotationMatrix, vertex);
+
+    vertices[index].position = rotatedVertex;
+}
+    )";
+    ComPtr<slang::IGlobalSession> globalSession;
+    SlangGlobalSessionDesc globalDesc = {};
+    globalDesc.enableGLSL = false;
+    SLANG_CHECK(slang_createGlobalSession2(&globalDesc, globalSession.writeRef()) == SLANG_OK);
+    // SLANG_CHECK records a failure but does not stop the test, so bail out
+    // explicitly rather than dereference a null session below.
+    if (!globalSession)
+        return;
+    slang::TargetDesc targetDesc = {};
+    targetDesc.format = SLANG_SPIRV;
+    targetDesc.profile = globalSession->findProfile("spirv_1_5");
+    slang::SessionDesc sessionDesc = {};
+    sessionDesc.targetCount = 1;
+    sessionDesc.targets = &targetDesc;
+
+    auto start = platform::PerformanceCounter::now();
+    for (int pass = 0; pass < 100; pass++)
+    {
+        ComPtr<slang::ISession> session;
+        SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+        if (!session)
+            return;
+
+        ComPtr<slang::IBlob> diagnosticBlob;
+        auto module = session->loadModuleFromSourceString(
+ "m",
+ "m.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
+        SLANG_CHECK(module != nullptr);
+        if (!module)
+            return;
+
+        // The entry point is declared `[shader("compute")]`, so request the
+        // compute stage; asking for a different stage conflicts with the
+        // attribute in the shader source.
+        ComPtr<slang::IEntryPoint> entryPoint;
+        module->findAndCheckEntryPoint(
+ "main",
+ SLANG_STAGE_COMPUTE,
+ entryPoint.writeRef(),
+ diagnosticBlob.writeRef());
+        SLANG_CHECK(entryPoint != nullptr);
+        if (!entryPoint)
+            return;
+
+        slang::IComponentType* componentTypes[2] = {module, entryPoint.get()};
+        ComPtr<slang::IComponentType> composedProgram;
+        session->createCompositeComponentType(
+ componentTypes,
+ 2,
+ composedProgram.writeRef(),
+ diagnosticBlob.writeRef());
+        SLANG_CHECK(composedProgram != nullptr);
+        if (!composedProgram)
+            return;
+
+        ComPtr<slang::IComponentType> linkedProgram;
+        composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+        SLANG_CHECK(linkedProgram != nullptr);
+        if (!linkedProgram)
+            return;
+
+        // Without this check the benchmark would happily time a compilation
+        // that produced no code at all.
+        ComPtr<slang::IBlob> code;
+        linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
+        SLANG_CHECK(code != nullptr);
+        if (!code)
+            return;
+    }
+    auto time = platform::PerformanceCounter::getElapsedTimeInSeconds(start);
+    getTestReporter()->addExecutionTime(time);
+}