diff options
| author | Yong He <yonghe@outlook.com> | 2025-09-10 11:50:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-10 18:50:30 +0000 |
| commit | c5607e9d68e9082ada9441f1949937f6b16d5c7f (patch) | |
| tree | d7d947af4a8d29cb8d3631d9703f9ccfe8d26735 | |
| parent | ec42c4a20facbcae441cd172bfd607614e761907 (diff) | |
Fix crash when compiling specialized generic entrypoint containing a static const decl. (#8392)
Closes #8184.
We fixed three issues with this regression test:
1. After generating IR for a `SpecializeComponentType`, we should also
strip the frontend
decorations from the IR so there is no HighLevelDeclDecoration that will
go into the backend.
2. When lowering a static const inside a generic function, we should not
give the static const
a linkage, because it won't such constant will not appear in global
scope. Trying to give it a
linkage decoration will lead to the parent generic (for the function) to
have two duplicate
Export/Import decorations with different mangle names, and confuses the
linker.
3. Make sure internal exceptions does not leak through
`IComponentType::getEntryPointCode`/`getTargetCode`.
| -rw-r--r-- | source/slang/slang-linkable.cpp | 98 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-target-program.cpp | 26 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-gh8184.cpp | 120 |
4 files changed, 208 insertions, 49 deletions
diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp index 0eef62742..7a307cbc3 100644 --- a/source/slang/slang-linkable.cpp +++ b/source/slang/slang-linkable.cpp @@ -657,57 +657,75 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD { return artifact.get(); } - - // If the user hasn't specified any entry points, then we should - // discover all entrypoints that are defined in linked modules, and - // include all of them in the compile. - // - if (getEntryPointCount() == 0) + try { - List<Module*> modules; - this->enumerateModules([&](Module* module) { modules.add(module); }); - List<RefPtr<ComponentType>> components; - components.add(this); - bool entryPointsDiscovered = false; - for (auto module : modules) + // If the user hasn't specified any entry points, then we should + // discover all entrypoints that are defined in linked modules, and + // include all of them in the compile. + // + if (getEntryPointCount() == 0) { - for (auto entryPoint : module->getEntryPoints()) + List<Module*> modules; + this->enumerateModules([&](Module* module) { modules.add(module); }); + List<RefPtr<ComponentType>> components; + components.add(this); + bool entryPointsDiscovered = false; + for (auto module : modules) { - components.add(entryPoint); - entryPointsDiscovered = true; + for (auto entryPoint : module->getEntryPoints()) + { + components.add(entryPoint); + entryPointsDiscovered = true; + } } - } - // If any entry points were discovered, then we should emit the program with entrypoints - // linked. - if (entryPointsDiscovered) - { - RefPtr<CompositeComponentType> composite = - new CompositeComponentType(linkage, components); - ComPtr<IComponentType> linkedComponentType; - SLANG_RETURN_NULL_ON_FAIL( - composite->link(linkedComponentType.writeRef(), outDiagnostics)); - auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get()) - ->getTargetArtifact(targetIndex, outDiagnostics); - if (targetArtifact) + // If any entry points were discovered, then we should emit the program with entrypoints + // linked. + if (entryPointsDiscovered) { - m_targetArtifacts[targetIndex] = targetArtifact; + RefPtr<CompositeComponentType> composite = + new CompositeComponentType(linkage, components); + ComPtr<IComponentType> linkedComponentType; + SLANG_RETURN_NULL_ON_FAIL( + composite->link(linkedComponentType.writeRef(), outDiagnostics)); + auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get()) + ->getTargetArtifact(targetIndex, outDiagnostics); + if (targetArtifact) + { + m_targetArtifacts[targetIndex] = targetArtifact; + } + return targetArtifact; } - return targetArtifact; } - } - auto target = linkage->targets[targetIndex]; - auto targetProgram = getTargetProgram(target); + auto target = linkage->targets[targetIndex]; + auto targetProgram = getTargetProgram(target); - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - IArtifact* targetArtifact = targetProgram->getOrCreateWholeProgramResult(&sink); - sink.getBlobIfNeeded(outDiagnostics); - m_targetArtifacts[targetIndex] = ComPtr<IArtifact>(targetArtifact); - return targetArtifact; + IArtifact* targetArtifact = targetProgram->getOrCreateWholeProgramResult(&sink); + sink.getBlobIfNeeded(outDiagnostics); + m_targetArtifacts[targetIndex] = ComPtr<IArtifact>(targetArtifact); + return targetArtifact; + } + catch (const Exception& e) + { + if (outDiagnostics && !*outDiagnostics) + { + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + sink.diagnose( + SourceLoc(), + Diagnostics::compilationAbortedDueToException, + typeid(e).name(), + e.Message); + sink.getBlobIfNeeded(outDiagnostics); + } + return nullptr; + } } SLANG_NO_THROW SlangResult SLANG_MCALL diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4df778ee6..3fbf0bda9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8896,6 +8896,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IRGeneric* outerGeneric = nullptr; + bool needLinkage = true; + // If we are static, then we need to insert the declaration before the parent. // This tries to match the behavior of previous `lowerFunctionStaticConstVarDecl` // functionality @@ -8906,6 +8908,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // be the global scope, but it might be an outer // generic if we are lowering a generic function. subBuilder->setInsertBefore(subBuilder->getFunc()); + + // static values inside a function does not need a linkage. + // trying to insert a linkage decoration to a static constant defined + // inside a generic function can lead to errorneous IR. + needLinkage = false; } else if (!isFunctionVarDecl(decl)) { @@ -8960,8 +8967,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // All of the attributes/decorations we can attach // belong on the IR constant node. // - - addLinkageDecoration(context, irConstant, decl); + if (needLinkage) + addLinkageDecoration(context, irConstant, decl); addNameHint(context, irConstant, decl); addVarDecorations(context, irConstant, decl); @@ -12034,6 +12041,8 @@ static void lowerProgramEntryPointToIR( existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); } + + stripFrontEndOnlyInstructions(builder->getModule(), IRStripOptions()); } /// Ensure that `decl` and all relevant declarations under it get emitted. diff --git a/source/slang/slang-target-program.cpp b/source/slang/slang-target-program.cpp index ffb859b55..d7e7a73bb 100644 --- a/source/slang/slang-target-program.cpp +++ b/source/slang/slang-target-program.cpp @@ -98,16 +98,28 @@ IArtifact* TargetProgram::getOrCreateEntryPointResult(Int entryPointIndex, Diagn if (IArtifact* artifact = m_entryPointResults[entryPointIndex]) return artifact; - // If we haven't yet computed a layout for this target - // program, we need to make sure that is done before - // code generation. - // - if (!getOrCreateIRModuleForLayout(sink)) + try + { + // If we haven't yet computed a layout for this target + // program, we need to make sure that is done before + // code generation. + // + if (!getOrCreateIRModuleForLayout(sink)) + { + return nullptr; + } + + return _createEntryPointResult(entryPointIndex, sink); + } + catch (const Exception& e) { + sink->diagnose( + SourceLoc(), + Diagnostics::compilationAbortedDueToException, + typeid(e).name(), + e.Message); return nullptr; } - - return _createEntryPointResult(entryPointIndex, sink); } } // namespace Slang diff --git a/tools/slang-unit-test/unit-test-gh8184.cpp b/tools/slang-unit-test/unit-test-gh8184.cpp new file mode 100644 index 000000000..59fd24ba3 --- /dev/null +++ b/tools/slang-unit-test/unit-test-gh8184.cpp @@ -0,0 +1,120 @@ +// unit-test-gh8184.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +// A regression test for github issue 8184. +// +// We fixed three issues with this regression test: +// 1. After generating IR for a SpecializeComponentType, we should also strip the frontend +// decorations from the IR so there is no HighLevelDeclDecoration that will go into the backend. +// 2. When lowering a static const inside a generic function, we should not give the static const +// a linkage, because it won't such constant will not appear in global scope. Trying to give it a +// linkage decoration will lead to the parent generic (for the function) to have two duplicate +// Export/Import decorations with different mangle names, and confuses the linker. +// 3. Make sure internal exceptions does not leak through +// IComponentType::getEntryPointCode/getTargetCode. +// +SLANG_UNIT_TEST(gh8184) +{ + ComPtr<slang::IGlobalSession> globalSession; + { + SlangGlobalSessionDesc globalDesc = {}; + SLANG_CHECK_ABORT(createGlobalSession(&globalDesc, globalSession.writeRef()) == SLANG_OK); + } + + ComPtr<slang::ISession> session; + { + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_WGSL; + + slang::SessionDesc sessionDesc = {}; + sessionDesc.targets = &targetDesc; + sessionDesc.targetCount = 1; + sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + + SLANG_CHECK_ABORT( + globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + } + + const char* shaderCode = R"SLANG( + interface Transformation + { + float3 apply(float3 coord); + static const uint32_t kParamCount; + } + + struct T1 : Transformation { + float3 apply(float3 coord) { return float3(0, 0, 0); } + static const uint32_t kParamCount = 2; + }; + + struct T2 : Transformation { + float3 apply(float3 coord) { return float3(0, 0, 0); } + static const uint32_t kParamCount = 4; + }; + + [shader("compute")] + [numthreads(1, 1, 1)] + void XYPass<T>() + where T : Transformation + { + static const uint32_t kParamCount = T::kParamCount; + } + )SLANG"; + + Slang::ComPtr<slang::IModule> module; + Slang::ComPtr<slang::IBlob> diagnostics; + { + const char* moduleName = "bugrepro"; + const char* virtualPath = "bugrepro.slang"; + module = session->loadModuleFromSourceString( + moduleName, + virtualPath, + shaderCode, + diagnostics.writeRef()); + SLANG_CHECK_ABORT(module != nullptr); + } + + Slang::ComPtr<slang::IEntryPoint> entryPoint; + SLANG_CHECK_ABORT(module->findEntryPointByName("XYPass", entryPoint.writeRef()) == SLANG_OK); + + Slang::ComPtr<slang::IComponentType> specializedEntryPoint; + { + slang::ProgramLayout* programLayout = module->getLayout(); + SLANG_CHECK_ABORT(programLayout != nullptr); + auto* t1Type = programLayout->findTypeByName("T1"); + SLANG_CHECK_ABORT(t1Type != nullptr); + + slang::SpecializationArg arg = {}; + arg.kind = slang::SpecializationArg::Kind::Type; + arg.type = t1Type; + + SLANG_CHECK_ABORT( + entryPoint + ->specialize(&arg, 1, specializedEntryPoint.writeRef(), diagnostics.writeRef()) == + SLANG_OK); + } + + Slang::ComPtr<slang::IComponentType> program; + { + slang::IComponentType* components[] = {module.get(), specializedEntryPoint.get()}; + SLANG_CHECK_ABORT( + session->createCompositeComponentType(components, 2, program.writeRef()) == SLANG_OK); + } + + Slang::ComPtr<slang::IComponentType> linked; + SLANG_CHECK_ABORT(program->link(linked.writeRef(), diagnostics.writeRef()) == SLANG_OK); + + Slang::ComPtr<slang::IBlob> code; + SLANG_CHECK( + linked->getEntryPointCode(0, 0, code.writeRef(), diagnostics.writeRef()) == SLANG_OK); +} |
