summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAnders Leino <aleino@nvidia.com>2025-03-05 08:16:29 +0200
committerGitHub <noreply@github.com>2025-03-05 06:16:29 +0000
commit5248a0254a48382d06ecb190c9f87c0ab62ff534 (patch)
tree4bc2abc4d4394083e0a45732226f34232756481b
parent6f56b473f4ab49dd6ec111b56cfc1701196f9c8c (diff)
Fix codegen bug when targeting PTX with new API (#6506)
* Add cuda codegen bug repro This just compiles tests/compute/simple.slang for PTX with the new compilation API, in order to reproduce a code generation bug. * Detect entrypoint more robustly when applying ConstRef hack during lowering For shaders like tests/compute/simple.slang, which have a 'numthreads' attribute but no 'shader' attribute, the old compile request API would add an EntryPointAttribute to the AST node of the entry point. However, the new API doesn't, and so a certain ConstRef hack doesn't get applied when using the new API, leading to subsequent code generation issues. This patch also checks for a 'numthreads' attribute when deciding whether to apply the ConstRef hack. This closes issue #6507 and helps to resolve issue #4760. * Add expected failure list for GitHub runners Our GitHub runners don't have the CUDA toolkits installed, so they can't run all tests.
-rw-r--r--.github/workflows/ci.yml6
-rw-r--r--source/slang/slang-lower-to-ir.cpp3
-rw-r--r--tests/expected-failure-github-runner.txt1
-rw-r--r--tools/slang-unit-test/unit-test-find-check-entrypoint.cpp65
4 files changed, 72 insertions, 3 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 66d33bc08..27e255372 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -176,14 +176,16 @@ jobs:
-category ${{ matrix.test-category }} \
-api all-dx12 \
-expected-failure-list tests/expected-failure-github.txt \
- -expected-failure-list tests/expected-failure-record-replay-tests.txt
+ -expected-failure-list tests/expected-failure-record-replay-tests.txt \
+ -expected-failure-list tests/expected-failure-github-runner.txt
else
"$bin_dir/slang-test" \
-use-test-server \
-category ${{ matrix.test-category }} \
-api all-dx12 \
-expected-failure-list tests/expected-failure-github.txt \
- -expected-failure-list tests/expected-failure-record-replay-tests.txt
+ -expected-failure-list tests/expected-failure-record-replay-tests.txt \
+ -expected-failure-list tests/expected-failure-github-runner.txt
fi
- name: Run Slang examples
if: steps.filter.outputs.should-run == 'true' && matrix.platform != 'wasm' && matrix.full-gpu-tests
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e5ca77634..4d692b727 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3214,7 +3214,8 @@ void collectParameterLists(
// For now we will rely on a follow up pass to remove unnecessary temporary variables if
// we can determine that they are never actually writtten to by the user.
//
- bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>();
+ bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>() ||
+ declRef.getDecl()->hasModifier<NumThreadsAttribute>();
// Don't collect parameters from the outer scope if
// we are in a `static` context.
diff --git a/tests/expected-failure-github-runner.txt b/tests/expected-failure-github-runner.txt
new file mode 100644
index 000000000..1da3a9669
--- /dev/null
+++ b/tests/expected-failure-github-runner.txt
@@ -0,0 +1 @@
+slang-unit-test-tool/cudaCodeGenBug.internal
diff --git a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
index 8ecab9671..75da9aaf0 100644
--- a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
+++ b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
@@ -71,3 +71,68 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
SLANG_CHECK(code != nullptr);
SLANG_CHECK(code->getBufferSize() != 0);
}
+
+// Regression test for issue #6507: compiling tests/compute/simple.slang for the
+// PTX target through the new compilation API generated invalid code.
+// TODO: Remove this when issue #4760 is resolved, because at that point
+// tests/compute/simple.slang should cover the same issue.
+SLANG_UNIT_TEST(cudaCodeGenBug)
+{
+ // Entry point that carries only a 'numthreads' attribute (no 'shader' attribute).
+ const char* userSourceBody = R"(
+ RWStructuredBuffer<float> outputBuffer;
+
+ [numthreads(4, 1, 1)]
+ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+ {
+ outputBuffer[dispatchThreadID.x] = float(dispatchThreadID.x);
+ }
+ )";
+
+ // Create a session whose only target is PTX.
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_PTX;
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ // Load the source as module "m", then find, compose, link and compile the entry point.
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "m",
+ "m.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint(
+ "computeMain",
+ SLANG_STAGE_COMPUTE,
+ entryPoint.writeRef(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(entryPoint != nullptr);
+
+ ComPtr<slang::IComponentType> compositeProgram;
+ slang::IComponentType* components[] = {module, entryPoint.get()};
+ session->createCompositeComponentType(
+ components,
+ 2,
+ compositeProgram.writeRef(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(compositeProgram != nullptr);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(linkedProgram != nullptr);
+
+ ComPtr<slang::IBlob> code;
+ auto res = linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(res == SLANG_OK);
+ SLANG_CHECK(code != nullptr);
+ SLANG_CHECK(code->getBufferSize() != 0);
+}