diff options
| author | James Helferty (NVIDIA) <jhelferty@nvidia.com> | 2025-06-26 14:15:43 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-26 18:15:43 +0000 |
| commit | 78dc7c3fd1441ba756c969b81ed90600fa9f9f38 (patch) | |
| tree | 61e778010669795c36c61aa8433d8311813af3bb /tools/slang-unit-test/unit-test-atomic-reflection.cpp | |
| parent | 1be819f4f0dbd001d1b222d7461a4fd87452dee2 (diff) | |
Hide atomics struct from reflection api (#7520)
* Atomic reflection unit test
Test that Atomic<int> is reflected as a scalar/int instead of a struct
named Atomic.
* Hide AtomicType from reflection
Leverage the fact that returned variables go through convert() to
centralize the unwrapping of the AtomicType struct.
Fixes #6257
* Clean up unit test a bit
- rename bar to buf (since one of them is not Atomic and thus not a
barrier)
- validate deeper into the fields of bufC
- remove debug code
Diffstat (limited to 'tools/slang-unit-test/unit-test-atomic-reflection.cpp')
| -rw-r--r-- | tools/slang-unit-test/unit-test-atomic-reflection.cpp | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-atomic-reflection.cpp b/tools/slang-unit-test/unit-test-atomic-reflection.cpp new file mode 100644 index 000000000..9817f2aa6 --- /dev/null +++ b/tools/slang-unit-test/unit-test-atomic-reflection.cpp @@ -0,0 +1,135 @@ +// unit-test-atomic-reflection.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +SLANG_UNIT_TEST(atomicReflection) +{ + const char* userSourceBody = R"( + RWStructuredBuffer<uint> bufA; // for reference + RWStructuredBuffer<Atomic<uint>> bufB; + struct TestStruct { + Atomic<uint> fieldA; + Atomic<uint> fieldB[2]; + vector<Atomic<uint>, 2> fieldC; + }; + RWStructuredBuffer<TestStruct> bufC; + + [shader("compute")] + [numthreads(64, 1, 1)] + void main(uint2 dispatchThreadId : SV_DispatchThreadID) + { + } + )"; + String userSource = userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto reflection = module->getLayout(); + + auto bufAVariableLayout = reflection->getParameterByIndex(0); + auto bufBVariableLayout = reflection->getParameterByIndex(1); + auto bufCVariableLayout = reflection->getParameterByIndex(2); + + SLANG_CHECK(bufAVariableLayout != nullptr); + SLANG_CHECK(bufBVariableLayout != nullptr); + SLANG_CHECK(bufCVariableLayout != nullptr); + + // This test walks down the parameters in a different way from the JSON + // option. The slangc option "-reflection-json" walks: + // param->getTypeLayout()->getElementTypeLayout()->getType()->getKind() + // This test uses an alternate walk: + // param->getVariable()->getType()->getElementType()->getKind() + + auto bufAVariable = bufAVariableLayout->getVariable(); + auto bufBVariable = bufBVariableLayout->getVariable(); + auto bufCVariable = bufCVariableLayout->getVariable(); + + SLANG_CHECK(bufAVariable != nullptr); + SLANG_CHECK(bufBVariable != nullptr); + SLANG_CHECK(bufCVariable != nullptr); + + auto bufAType = bufAVariable->getType(); + auto bufBType = bufBVariable->getType(); + auto bufCType = bufCVariable->getType(); + + SLANG_CHECK(bufAType->getKind() == slang::TypeReflection::Kind::Resource); + SLANG_CHECK(bufBType->getKind() == slang::TypeReflection::Kind::Resource); + SLANG_CHECK(bufCType->getKind() == slang::TypeReflection::Kind::Resource); + + auto bufAElementType = bufAType->getElementType(); + auto bufBElementType = bufBType->getElementType(); + auto bufCElementType = bufCType->getElementType(); + + SLANG_CHECK(bufAElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufBElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufAElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + SLANG_CHECK(bufBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + auto bufAResResultType = bufAType->getResourceResultType(); + auto bufBResResultType = bufBType->getResourceResultType(); + + SLANG_CHECK(bufAResResultType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufBResResultType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufAResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + SLANG_CHECK(bufBResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + // Atomics embedded in structs require traversing the fields + SLANG_CHECK(bufCElementType->getKind() == slang::TypeReflection::Kind::Struct); + + auto bufCFieldCount = bufCElementType->getFieldCount(); + SLANG_CHECK(bufCFieldCount == 3); + + auto fieldA = bufCElementType->getFieldByIndex(0); + auto fieldB = bufCElementType->getFieldByIndex(1); + auto fieldC = bufCElementType->getFieldByIndex(2); + + SLANG_CHECK(fieldA != nullptr); + SLANG_CHECK(fieldB != nullptr); + SLANG_CHECK(fieldC != nullptr); + + auto fieldAType = fieldA->getType(); + auto fieldBType = fieldB->getType(); + auto fieldCType = fieldC->getType(); + + SLANG_CHECK(fieldAType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(fieldAType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + SLANG_CHECK(fieldBType->getKind() == slang::TypeReflection::Kind::Array); + SLANG_CHECK(fieldBType->getElementCount() == 2); + auto fieldBElementType = fieldBType->getElementType(); + SLANG_CHECK(fieldBElementType != nullptr); + SLANG_CHECK(fieldBElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(fieldBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + SLANG_CHECK(fieldCType->getKind() == slang::TypeReflection::Kind::Vector); + SLANG_CHECK(fieldCType->getElementCount() == 2); + auto fieldCElementType = fieldCType->getElementType(); + SLANG_CHECK(fieldCElementType != nullptr); + SLANG_CHECK(fieldCElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(fieldCElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); +} |
