summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJames Helferty (NVIDIA) <jhelferty@nvidia.com>2025-06-26 14:15:43 -0400
committerGitHub <noreply@github.com>2025-06-26 18:15:43 +0000
commit78dc7c3fd1441ba756c969b81ed90600fa9f9f38 (patch)
tree61e778010669795c36c61aa8433d8311813af3bb
parent1be819f4f0dbd001d1b222d7461a4fd87452dee2 (diff)
Hide atomics struct from reflection api (#7520)
* Atomic reflection unit test Test that Atomic<int> is reflected as a scalar/int instead of a struct named Atomic. * Hide AtomicType from reflection Leverage the fact that returned variables go through convert() to centralize the unwrapping of the AtomicType struct. Fixes #6257 * Clean up unit test a bit - rename bar to buf (since one of them is not Atomic and thus not a barrier) - validate deeper into the fields of bufC - remove debug code
-rw-r--r--source/slang/slang-reflection-api.cpp14
-rw-r--r--tools/slang-unit-test/unit-test-atomic-reflection.cpp135
2 files changed, 145 insertions, 4 deletions
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 8688f5e41..b0d88e954 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -40,6 +40,12 @@ static inline Type* convert(SlangReflectionType* type)
static inline SlangReflectionType* convert(Type* type)
{
+ // Prevent the AtomicType struct from being visible to the user
+ // through the reflection API.
+ if (auto atomicType = as<AtomicType>(type))
+ {
+ return (SlangReflectionType*)atomicType->getElementType();
+ }
return (SlangReflectionType*)type;
}
@@ -586,7 +592,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy
if (auto arrayType = as<ArrayExpressionType>(type))
{
- return (SlangReflectionType*)arrayType->getElementType();
+ return convert(arrayType->getElementType());
}
else if (auto parameterGroupType = as<ParameterGroupType>(type))
{
@@ -1026,7 +1032,7 @@ SLANG_API SlangReflectionType* spReflection_FindTypeByName(
if (as<ErrorType>(result))
return nullptr;
- return (SlangReflectionType*)result;
+ return convert(result);
}
catch (...)
{
@@ -1158,7 +1164,7 @@ SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType(
if (!typeLayout)
return nullptr;
- return (SlangReflectionType*)typeLayout->type;
+ return convert(typeLayout->type);
}
SLANG_API SlangTypeKind spReflectionTypeLayout_getKind(SlangReflectionTypeLayout* inTypeLayout)
@@ -4227,7 +4233,7 @@ SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(
{
auto constraints =
globalGenericParamDecl->getMembersOfType<GenericTypeConstraintDecl>();
- return (SlangReflectionType*)constraints[index]->sup.Ptr();
+ return convert(constraints[index]->sup.Ptr());
}
// TODO: Add case for entry-point generic parameters.
}
diff --git a/tools/slang-unit-test/unit-test-atomic-reflection.cpp b/tools/slang-unit-test/unit-test-atomic-reflection.cpp
new file mode 100644
index 000000000..9817f2aa6
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-atomic-reflection.cpp
@@ -0,0 +1,135 @@
+// unit-test-atomic-reflection.cpp
+
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+#include "unit-test/slang-unit-test.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+using namespace Slang;
+
+SLANG_UNIT_TEST(atomicReflection)
+{
+ const char* userSourceBody = R"(
+ RWStructuredBuffer<uint> bufA; // for reference
+ RWStructuredBuffer<Atomic<uint>> bufB;
+ struct TestStruct {
+ Atomic<uint> fieldA;
+ Atomic<uint> fieldB[2];
+ vector<Atomic<uint>, 2> fieldC;
+ };
+ RWStructuredBuffer<TestStruct> bufC;
+
+ [shader("compute")]
+ [numthreads(64, 1, 1)]
+ void main(uint2 dispatchThreadId : SV_DispatchThreadID)
+ {
+ }
+ )";
+ String userSource = userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_HLSL;
+ targetDesc.profile = globalSession->findProfile("sm_5_0");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "m",
+ "m.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ auto reflection = module->getLayout();
+
+ auto bufAVariableLayout = reflection->getParameterByIndex(0);
+ auto bufBVariableLayout = reflection->getParameterByIndex(1);
+ auto bufCVariableLayout = reflection->getParameterByIndex(2);
+
+ SLANG_CHECK(bufAVariableLayout != nullptr);
+ SLANG_CHECK(bufBVariableLayout != nullptr);
+ SLANG_CHECK(bufCVariableLayout != nullptr);
+
+ // This test walks down the parameters in a different way from the JSON
+ // option. The slangc option "-reflection-json" walks:
+ // param->getTypeLayout()->getElementTypeLayout()->getType()->getKind()
+ // This test uses an alternate walk:
+ // param->getVariable()->getType()->getElementType()->getKind()
+
+ auto bufAVariable = bufAVariableLayout->getVariable();
+ auto bufBVariable = bufBVariableLayout->getVariable();
+ auto bufCVariable = bufCVariableLayout->getVariable();
+
+ SLANG_CHECK(bufAVariable != nullptr);
+ SLANG_CHECK(bufBVariable != nullptr);
+ SLANG_CHECK(bufCVariable != nullptr);
+
+ auto bufAType = bufAVariable->getType();
+ auto bufBType = bufBVariable->getType();
+ auto bufCType = bufCVariable->getType();
+
+ SLANG_CHECK(bufAType->getKind() == slang::TypeReflection::Kind::Resource);
+ SLANG_CHECK(bufBType->getKind() == slang::TypeReflection::Kind::Resource);
+ SLANG_CHECK(bufCType->getKind() == slang::TypeReflection::Kind::Resource);
+
+ auto bufAElementType = bufAType->getElementType();
+ auto bufBElementType = bufBType->getElementType();
+ auto bufCElementType = bufCType->getElementType();
+
+ SLANG_CHECK(bufAElementType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(bufBElementType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(bufAElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+ SLANG_CHECK(bufBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+
+ auto bufAResResultType = bufAType->getResourceResultType();
+ auto bufBResResultType = bufBType->getResourceResultType();
+
+ SLANG_CHECK(bufAResResultType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(bufBResResultType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(bufAResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+ SLANG_CHECK(bufBResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+
+ // Atomics embedded in structs require traversing the fields
+ SLANG_CHECK(bufCElementType->getKind() == slang::TypeReflection::Kind::Struct);
+
+ auto bufCFieldCount = bufCElementType->getFieldCount();
+ SLANG_CHECK(bufCFieldCount == 3);
+
+ auto fieldA = bufCElementType->getFieldByIndex(0);
+ auto fieldB = bufCElementType->getFieldByIndex(1);
+ auto fieldC = bufCElementType->getFieldByIndex(2);
+
+ SLANG_CHECK(fieldA != nullptr);
+ SLANG_CHECK(fieldB != nullptr);
+ SLANG_CHECK(fieldC != nullptr);
+
+ auto fieldAType = fieldA->getType();
+ auto fieldBType = fieldB->getType();
+ auto fieldCType = fieldC->getType();
+
+ SLANG_CHECK(fieldAType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(fieldAType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+
+ SLANG_CHECK(fieldBType->getKind() == slang::TypeReflection::Kind::Array);
+ SLANG_CHECK(fieldBType->getElementCount() == 2);
+ auto fieldBElementType = fieldBType->getElementType();
+ SLANG_CHECK(fieldBElementType != nullptr);
+ SLANG_CHECK(fieldBElementType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(fieldBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+
+ SLANG_CHECK(fieldCType->getKind() == slang::TypeReflection::Kind::Vector);
+ SLANG_CHECK(fieldCType->getElementCount() == 2);
+ auto fieldCElementType = fieldCType->getElementType();
+ SLANG_CHECK(fieldCElementType != nullptr);
+ SLANG_CHECK(fieldCElementType->getKind() == slang::TypeReflection::Kind::Scalar);
+ SLANG_CHECK(fieldCElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
+}