diff options
| author | Yong He <yonghe@outlook.com> | 2025-02-12 14:58:00 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-12 14:58:00 -0800 |
| commit | 74852ceb6b3bcc018042aba3e30933b7b6fc09ef (patch) | |
| tree | b77dd62f0e74510fae0d2af8b7afa260ce9d49b8 | |
| parent | 3f102afe1038882f336dc052a9954811150fa700 (diff) | |
Allow LHS of `where` to be any type. (#6333)
* Allow LHS of `where` to be any type.
* Register free-form extensions when loading precompiled module.
* Fix test.
* Fix.
* Fix `as<IRType>`.
* try fix precompiled module test.
| -rw-r--r-- | source/slang/hlsl.meta.slang | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 133 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-hlsl-legalize.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-use-uninitialized-values.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-serialize-container.cpp | 16 | ||||
| -rw-r--r-- | tests/bugs/gh-6331.slang | 91 | ||||
| -rw-r--r-- | tools/gfx-unit-test/precompiled-module-cache.cpp | 4 |
11 files changed, 222 insertions, 64 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a10e747c0..81b28b30a 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -524,6 +524,13 @@ interface ITexelElement __init(Element x); } +extension<T:__BuiltinArithmeticType> T : ITexelElement +{ + typealias Element = T; + static const int elementCount = 1; + __intrinsic_op(0) __init(Element x); +} + ${{{ // Scalar types that can be used as texel element. const char* texeElementScalarTypes[] = { @@ -539,12 +546,6 @@ const char* texeElementScalarTypes[] = { for (auto elementType : texeElementScalarTypes) { }}} -extension $(elementType) : ITexelElement -{ - typealias Element = $(elementType); - static const int elementCount = 1; - __intrinsic_op(0) __init(Element x); -} extension<int N> vector<$(elementType), N> : ITexelElement { typealias Element = $(elementType); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 76074f551..1ef5b1cec 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -9789,6 +9789,46 @@ void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl) } } +bool getExtensionTargetDeclList( + ASTBuilder* astBuilder, + DeclRefType* targetDeclRefType, + ExtensionDecl* extDecl, + ShortList<AggTypeDecl*>& targetDecls) +{ + if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>()) + { + auto aggTypeDecl = aggTypeDeclRef.getDecl(); + + targetDecls.add(aggTypeDecl); + return true; + } + + auto genericParamDeclRef = targetDeclRefType->getDeclRef().as<GenericTypeParamDeclBase>(); + if (!genericParamDeclRef) + return false; + + auto genericParent = as<GenericDecl>(genericParamDeclRef.getParent().getDecl()); + if (!genericParent) + return false; + + if (genericParent != extDecl->parentDecl) + return false; + + for (auto member : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericParent)) + { + if (getSub(astBuilder, member) == targetDeclRefType) + { + auto baseType = getSup(astBuilder, member); + if (auto baseTypeDecl = isDeclRefTypeOf<AggTypeDecl>(baseType)) + { + targetDecls.add(baseTypeDecl.getDecl()); + } + } + } + return targetDecls.getCount() != 0; +} + + void SemanticsDeclBasesVisitor::_validateExtensionDeclTargetType(ExtensionDecl* decl) { if (auto targetDeclRefType = as<DeclRefType>(decl->targetType)) @@ -11582,8 +11622,8 @@ void checkDerivativeAttributeImpl( auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef); // If the function is a member function, we need to check that the - // `this` type matches the expected type. This will ensure that after lowering to - // IR, the two functions are compatible. + // `this` type matches the expected type. This will ensure that after lowering + // to IR, the two functions are compatible. // if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType)) { @@ -11971,8 +12011,9 @@ void checkDerivativeOfAttributeImpl( if (as<ErrorType>(resolved->type.type)) { - // If we can't resolve a type, something went wrong. If we're working with a generic - // decl, the most likely cause is a failure of generic argument inference. + // If we can't resolve a type, something went wrong. If we're working with a + // generic decl, the most likely cause is a failure of generic argument + // inference. // visitor->getSink()->diagnose( derivativeOfAttr, @@ -12284,8 +12325,8 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers( // Find the base type's members first for (auto inheritanceMember : structDecl->getMembersOfType<InheritanceDecl>()) { - // For base types, we need to pick their parameters of the constructor to the derived type's - // constructor + // For base types, we need to pick their parameters of the constructor to the derived + // type's constructor if (auto baseTypeDeclRef = isDeclRefTypeOf<StructDecl>(inheritanceMember->base.type)) { // We should only find the member initialization constructor because it is the @@ -12294,15 +12335,15 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers( baseTypeDeclRef.getDecl(), ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit); - // The constructor has to have higher or equal visibility level than the struct itself, - // otherwise, it's not accessible so we will not pick up. + // The constructor has to have higher or equal visibility level than the struct + // itself, otherwise, it's not accessible so we will not pick up. if (ctor && getDeclVisibility(ctor) >= ctorVisibility) { for (ParamDecl* param : ctor->getParameters()) { - // Because the parameters in the ctor must have the higher or equal visibility - // than the ctor itself, we don't need to check the visibility level of the - // parameter. + // Because the parameters in the ctor must have the higher or equal + // visibility than the ctor itself, we don't need to check the visibility + // level of the parameter. resultMembers.add(param); } } @@ -12342,10 +12383,9 @@ static Expr* _getParamDefaultValue(SemanticsVisitor* visitor, VarDeclBase* varDe bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* structDecl) { - // If a type or its base type already defines any explicit constructors, do not synthesize any - // constructors. - // See - // https://github.com/shader-slang/spec/blob/main/proposals/004-initialization.md#inheritance-initialization + // If a type or its base type already defines any explicit constructors, do not synthesize + // any constructors. see: + // https://github.com/shader-slang/slang/blob/master/docs/proposals/004-initialization.md#inheritance-initialization if (_hasExplicitConstructor(structDecl, true)) return false; @@ -12397,9 +12437,9 @@ bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* struct ctorParam->loc = ctor->loc; ctor->members.add(ctorParam); - // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` modifiers do - // not matter in this case since member-wise ctor is always differentiable or "treat as - // differentiable". + // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` + // modifiers do not matter in this case since member-wise ctor is always differentiable + // or "treat as differentiable". if (!isTypeDifferentiable(member->getType()) || member->hasModifier<NoDiffModifier>()) { auto noDiffMod = m_astBuilder->create<NoDiffModifier>(); @@ -12559,7 +12599,8 @@ void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl) totalWidth += int(thisFieldWidth); groupInfo.add({memberIndex, int(thisFieldWidth), t, bfm}); } - // If the struct ended with a bitpacked member, then make sure we don't forget the last group + // If the struct ended with a bitpacked member, then make sure we don't forget the last + // group dispatchSomeBitPackedMembers(); } @@ -12630,8 +12671,8 @@ static void _propagateRequirement( if (!isAnyInvalid && resultCaps.isInvalid()) { // If joining the referenced decl's requirements results an invalid capability set, - // then the decl is using things that require conflicting set of capabilities, and we should - // diagnose an error. + // then the decl is using things that require conflicting set of capabilities, and we + // should diagnose an error. if (referencedDecl && decl) { maybeDiagnose( @@ -12736,17 +12777,17 @@ struct CapabilityDeclReferenceVisitor // `calling_functions_targets`: // ``` default_target = calling_functions_targets-{other_case_targets} ``` // - // * `calling_functions_capability` = `requirement attribute` of the calling function; - // if missing + // * `calling_functions_capability` = `requirement attribute` of the calling + // function; if missing // we can assume it is `any_target` // - // * `{other_case_targets}` = set of all capabilities all `case` statments target inside - // the `__target_switch` + // * `{other_case_targets}` = set of all capabilities all `case` statments target + // inside the `__target_switch` - // If we do not handle `default:`, the codegen will fail when trying to find a specific - // codegen target not handled explicitly by a `case` statment. - // We must also ensure the `default` case is last so we have priority to hit `case` - // statments and can preprocess `case` statments before the `default` case. + // If we do not handle `default:`, the codegen will fail when trying to find a + // specific codegen target not handled explicitly by a `case` statment. We must also + // ensure the `default` case is last so we have priority to hit `case` statments and + // can preprocess `case` statments before the `default` case. CapabilitySet targetCap; if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == CapabilityName::Invalid) @@ -12901,7 +12942,8 @@ CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* dec // For every existing target, we want to join their requirements together. // If the the parent defines additional targets, we want to add them to the disjunction set. // For example: - // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); } + // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void + // foo(); } // The requirement for `foo` should be glsl+glsl_ext_1 | spirv. // CapabilitySet declaredCaps; @@ -12990,8 +13032,8 @@ static inline void _dispatchCapabilitiesVisitorOfFunctionDecl( void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl) { - // If the function is an entrypoint and specifies a target stage, add the capabilities to our - // function capabilities. + // If the function is an entrypoint and specifies a target stage, add the capabilities to + // our function capabilities. _dispatchCapabilitiesVisitorOfFunctionDecl( this, funcDecl, @@ -13012,8 +13054,8 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun auto vis = getDeclVisibility(funcDecl); - // If 0 capabilities were annotated on a function, capabilities are inferred from the function - // body + // If 0 capabilities were annotated on a function, capabilities are inferred from the + // function body if (declaredCaps.isEmpty()) { declaredCaps = funcDecl->inferredCapabilityRequirements; @@ -13133,7 +13175,8 @@ DeclVisibility getDeclVisibility(Decl* decl) : parentModule->defaultVisibility; } - // Members of other agg type decls will have their default visibility capped to the parents'. + // Members of other agg type decls will have their default visibility capped to the + // parents'. if (as<NamespaceDecl>(decl)) { return DeclVisibility::Public; @@ -13345,15 +13388,16 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( // There are two causes for why type checking failed on failedAvailableSet. // The first scenario is that failedAvailableSet defines a set of capabilities on a - // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have - // a function: + // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we + // have a function: // [require(hlsl)] // <-- failedAvailableSet // [require(cpp)] // void caller() // { // printf(); // assume this is defined for (cpp | cuda). // } - // In this case we should diagnose error reporting printf isn't defined on a required target. + // In this case we should diagnose error reporting printf isn't defined on a required + // target. // // Now, we detect if we are case 1. @@ -13370,8 +13414,8 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( decl, outFailedAtom); - // Anything defined on a non-failed target atom may be the culprit to why we fail having - // a target capability. Print out all possible culprits. + // Anything defined on a non-failed target atom may be the culprit to why we fail + // having a target capability. Print out all possible culprits. CapabilityAtomSet failedAtomSet; failedAtomSet.add((UInt)outFailedAtom); CapabilityAtomSet targetsNotUsedSet; @@ -13395,9 +13439,12 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( } } - //// The second scenario is when the callee is using a capability that is not provided by the - /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / useD(); - ///// require capability (hlsl,d) / } / In this case we should report that useD() is using a + //// The second scenario is when the callee is using a capability that is not provided by + /// the + /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / + /// useD(); + ///// require capability (hlsl,d) / } / In this case we should report that useD() is + /// using a /// capability that is not declared by caller. //// diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 350362a0e..59290f8ad 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3066,4 +3066,17 @@ bool resolveStageOfProfileWithEntryPoint( const List<RefPtr<TargetRequest>>& targets, FuncDecl* entryPointFuncDecl, DiagnosticSink* sink); + +// For an extensions decl, collect a list of decls on which the extension might be applying to. +// For example, if we see a `extension Foo`, return a `Decl*` that represents `struct Foo`. +// In the case of free-form generic extensions i.e. `extension<T:IFoo> T : IBar`, return `IFoo`. +// These are the decls that we need to register the extension with in +// `mapTypeToCandidateExtensions`. +// Returns true when any base decls are found. +bool getExtensionTargetDeclList( + ASTBuilder* astBuilder, + DeclRefType* targetDeclRefType, + ExtensionDecl* extDeclRef, + ShortList<AggTypeDecl*>& targetDecls); + } // namespace Slang diff --git a/source/slang/slang-ir-hlsl-legalize.cpp b/source/slang/slang-ir-hlsl-legalize.cpp index ec2419985..7116d635e 100644 --- a/source/slang/slang-ir-hlsl-legalize.cpp +++ b/source/slang/slang-ir-hlsl-legalize.cpp @@ -36,7 +36,7 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in continue; auto forceStructArg = arg->getOperand(0); auto forceStructBaseType = - as<IRType>(forceStructArg->getDataType()->getOperand(0)); + (IRType*)(forceStructArg->getDataType()->getOperand(0)); IRBuilder builder(call); if (forceStructBaseType->getOp() == kIROp_StructType) { diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9c3892c0e..dbefa68c7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1143,7 +1143,7 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration IRUse pairType; IR_LEAF_ISA(MixedDifferentialInstDecoration) - IRType* getPairType() { return as<IRType>(getOperand(0)); } + IRType* getPairType() { return (IRType*)(getOperand(0)); } }; struct IRRecomputeBlockDecoration : IRAutodiffInstDecoration diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 364e58c48..53bbbda9e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -614,8 +614,8 @@ IRWitnessTable* cloneWitnessTableImpl( IRWitnessTable* clonedTable = dstTable; if (!clonedTable) { - auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType())); - auto clonedSubType = cloneType(context, as<IRType>(originalTable->getConcreteType())); + auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); + auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType); } cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp index 51de2117c..71aae5923 100644 --- a/source/slang/slang-ir-use-uninitialized-values.cpp +++ b/source/slang/slang-ir-use-uninitialized-values.cpp @@ -186,7 +186,7 @@ static bool canIgnoreType(IRType* type, IRType* upper) if (auto spec = as<IRSpecialize>(type)) { IRInst* inner = getResolvedInstForDecorations(spec); - IRType* innerType = as<IRType>(inner); + IRType* innerType = (IRType*)(inner); return canIgnoreType(innerType, upper); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 06cdf430b..ed8a52b9e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9453,10 +9453,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // should handle propgation of value-size information from constraints // back to generic parameters? // - if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type)) + if (auto genParamDeclRef = + isDeclRefTypeOf<GenericTypeParamDeclBase>(constraintDecl->sub.type)) { - auto typeParamDeclVal = - subContext->findLoweredDecl(declRefType->getDeclRef().getDecl()); + auto typeParamDeclVal = subContext->findLoweredDecl(genParamDeclRef.getDecl()); SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val); subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType); } diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index e5775f526..b5121d373 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -5,6 +5,7 @@ #include "../core/slang-math.h" #include "../core/slang-stream.h" #include "../core/slang-text-io.h" +#include "slang-check-impl.h" #include "slang-compiler.h" #include "slang-mangled-lexer.h" #include "slang-parser.h" @@ -813,15 +814,16 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( if (auto targetDeclRefType = as<DeclRefType>(extensionDecl->targetType)) { - // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = - targetDeclRefType->getDeclRef() - .as<AggTypeDecl>()) + ShortList<AggTypeDecl*> baseDecls; + getExtensionTargetDeclList( + astBuilder, + targetDeclRefType, + extensionDecl, + baseDecls); + for (auto baseDecl : baseDecls) { - auto aggTypeDecl = aggTypeDeclRef.getDecl(); - _getCandidateExtensionList( - aggTypeDecl, + baseDecl, moduleDecl->mapTypeToCandidateExtensions) .add(extensionDecl); } diff --git a/tests/bugs/gh-6331.slang b/tests/bugs/gh-6331.slang new file mode 100644 index 000000000..c3f786cfc --- /dev/null +++ b/tests/bugs/gh-6331.slang @@ -0,0 +1,91 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +// CHECK: OpEntryPoint + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +[ForceInline] +__generic<T : __BuiltinFloatingPointType> +public T accumulateSample(uint64_t pSampleIdx, T pOldValue, T pNewValue) { + T lTmp = isfinite(pOldValue) ? pOldValue : T(0); + return lTmp + (pNewValue - lTmp) / T(pSampleIdx + 1); +} +[ForceInline] +__generic<T : __BuiltinFloatingPointType, let N : int> +public vector<T, N> accumulateSample(uint64_t pSampleIdx, vector<T, N> pOldValue, vector<T, N> pNewValue) { + vector<T, N> lTmp = select(isfinite(pOldValue), pOldValue, T(0)); + return lTmp + (pNewValue - lTmp) / vector<T, N>(pSampleIdx + 1); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +public struct RenderBuffer2D<T : ITexelElement> { + public Texture2D<T> handle; + public uint16_t2 offset; + public uint16_t2 size; + + public __init() { + handle = {}; + offset = {}; + size = {}; + } + + [ForceInline] public bool isValid() { return size[0] > 0u && size[1] > 0u; } + [ForceInline] public uint64_t getPixelCount() { return (uint64_t)(size[0]) * (uint64_t)(size[1]); } +} + +public struct RWRenderBuffer2D<T : ITexelElement> { + public RWTexture2D<T> handle; + public uint16_t2 offset; + public uint16_t2 size; + + public __init() { + handle = {}; + offset = {}; + size = {}; + } + + [ForceInline] public bool isValid() { return size[0] > 0u && size[1] > 0u; } + [ForceInline] public uint64_t getPixelCount() { return (uint64_t)(size[0]) * (uint64_t)(size[1]); } +} + +__generic<T : __BuiltinFloatingPointType> +public extension RWRenderBuffer2D<T> where T : ITexelElement { + [ForceInline] + public void accumulate(uint32_t2 pDestIdx, uint64_t pSampleIdx, T pSampleValue) { + uint32_t2 lIndex = offset + pDestIdx; + handle[lIndex] = accumulateSample(pSampleIdx, handle[lIndex], pSampleValue); + } +} +__generic<T : __BuiltinFloatingPointType, let N : int> +public extension RWRenderBuffer2D<vector<T, N>> where vector<T,N>:ITexelElement { + [ForceInline] + public void accumulate(uint32_t2 pDestIdx, uint64_t pSampleIdx, vector<T, N> pSampleValue) { + uint32_t2 lIndex = offset + pDestIdx; + handle[lIndex] = accumulateSample(pSampleIdx, handle[lIndex], pSampleValue); + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Global Parameters +public uniform uint32_t gSampleIdx; +public uniform RenderBuffer2D<float4> gInput; +public uniform RWRenderBuffer2D<float4> gOutput; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +[shader("compute")] +[numthreads(16, 16, 1)] +void refineImage(uint3 pDispatchThreadIdx: SV_DispatchThreadID) { + const uint2 lPixelCoordinates = { pDispatchThreadIdx.x, pDispatchThreadIdx.y }; + + // Some sanity checks. + if (!gInput.isValid() || any(lPixelCoordinates >= gInput.size)) { + return; + } + if (!gOutput.isValid() || any(lPixelCoordinates >= gOutput.size)) { + return; + } + + gOutput.accumulate(lPixelCoordinates, gSampleIdx, gInput.handle[lPixelCoordinates]); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/gfx-unit-test/precompiled-module-cache.cpp b/tools/gfx-unit-test/precompiled-module-cache.cpp index 8c22a7c84..778c68a89 100644 --- a/tools/gfx-unit-test/precompiled-module-cache.cpp +++ b/tools/gfx-unit-test/precompiled-module-cache.cpp @@ -7,6 +7,7 @@ #include "slang-gfx.h" #include "unit-test/slang-unit-test.h" +#include <mutex> using namespace gfx; namespace gfx_test @@ -15,6 +16,9 @@ namespace gfx_test Slang::ComPtr<slang::ISession> createSession(gfx::IDevice* device, ISlangFileSystemExt* fileSys) { + static std::mutex m; + std::lock_guard<std ::mutex> lock(m); + Slang::ComPtr<slang::ISession> slangSession; device->getSlangSession(slangSession.writeRef()); slang::SessionDesc sessionDesc = {}; |
