diff options
Diffstat (limited to 'tools/gfx-unit-test')
3 files changed, 282 insertions, 0 deletions
diff --git a/tools/gfx-unit-test/link-time-constant-array-size-lib.slang b/tools/gfx-unit-test/link-time-constant-array-size-lib.slang new file mode 100644 index 000000000..ab72e8e2f --- /dev/null +++ b/tools/gfx-unit-test/link-time-constant-array-size-lib.slang @@ -0,0 +1 @@ +export static const int N = 4; diff --git a/tools/gfx-unit-test/link-time-constant-array-size-main.slang b/tools/gfx-unit-test/link-time-constant-array-size-main.slang new file mode 100644 index 000000000..6c5aea928 --- /dev/null +++ b/tools/gfx-unit-test/link-time-constant-array-size-main.slang @@ -0,0 +1,12 @@ +extern static const int N; + +struct S { int xs[N]; } + +RWStructuredBuffer<S> b; + +[numthreads(1, 1, 1)] +void computeMain() +{ + for(int i = 0; i < N; ++i) + b[0].xs[i] *= N; +} diff --git a/tools/gfx-unit-test/link-time-constant-array-size.cpp b/tools/gfx-unit-test/link-time-constant-array-size.cpp new file mode 100644 index 000000000..e67e5636c --- /dev/null +++ b/tools/gfx-unit-test/link-time-constant-array-size.cpp @@ -0,0 +1,269 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "gfx-util/shader-cursor.h" +#include "slang-gfx.h" +#include "unit-test/slang-unit-test.h" + +using namespace gfx; + +namespace gfx_test +{ +static Slang::Result loadProgram( + gfx::IDevice* device, + Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram, + const char* mainModuleName, + const char* libModuleName, + const char* entryPointName, + slang::ProgramLayout*& slangReflection) +{ + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + + // Load main module + slang::IModule* mainModule = + slangSession->loadModule(mainModuleName, diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if (!mainModule) + return SLANG_FAIL; + + // Load library module with constants + slang::IModule* libModule = slangSession->loadModule(libModuleName, diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if (!libModule) + return SLANG_FAIL; + + // Find entry point + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + mainModule->findEntryPointByName(entryPointName, computeEntryPoint.writeRef())); + + // Compose program from modules + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(mainModule); + componentTypes.add(libModule); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + // Link program + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + // Create shader program + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +// Function to validate the array size in struct S +static void validateArraySizeInStruct( + UnitTestContext* context, + slang::ProgramLayout* slangReflection, + int expectedSize) +{ + // Check reflection is available + SLANG_CHECK(slangReflection != nullptr); + + // Get the global scope layout + auto globalScope = slangReflection->getGlobalParamsVarLayout(); + SLANG_CHECK_MSG(globalScope != nullptr, "Could not get global scope layout"); + + auto typeLayout = globalScope->getTypeLayout(); + SLANG_CHECK_MSG(typeLayout != nullptr, "Global scope has no type layout"); + + // Check if the global scope is a struct type + auto kind = typeLayout->getKind(); + SLANG_CHECK_MSG( + kind == slang::TypeReflection::Kind::Struct, + "Global scope is not a struct type"); + + // Find the buffer resource 'b' + bool foundBuffer = false; + auto fieldCount = typeLayout->getFieldCount(); + + for (unsigned int i = 0; i < fieldCount; i++) + { + auto fieldLayout = typeLayout->getFieldByIndex(i); + const char* fieldName = fieldLayout->getName(); + + if (fieldName && strcmp(fieldName, "b") == 0) + { + foundBuffer = true; + + // Get the type layout of the field + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + SLANG_CHECK_MSG(fieldTypeLayout != nullptr, "Field has no type layout"); + + // Get the element type of the structured buffer + auto elementTypeLayout = fieldTypeLayout->getElementTypeLayout(); + SLANG_CHECK_MSG( + elementTypeLayout != nullptr, + "Structured buffer has no element type layout"); + + // Check if it's a struct type + auto elementKind = elementTypeLayout->getKind(); + SLANG_CHECK_MSG( + elementKind == slang::TypeReflection::Kind::Struct, + "Buffer element is not a struct type"); + + // Get the field count of the struct + auto structFieldCount = elementTypeLayout->getFieldCount(); + SLANG_CHECK_MSG(structFieldCount >= 1, "Struct has no fields"); + + // Check for the 'xs' field + bool foundXsField = false; + for (unsigned int j = 0; j < structFieldCount; j++) + { + auto structField = elementTypeLayout->getFieldByIndex(j); + const char* structFieldName = structField->getName(); + + if (structFieldName && strcmp(structFieldName, "xs") == 0) + { + foundXsField = true; + + // Check that it's an array type + auto structFieldTypeLayout = structField->getTypeLayout(); + auto structFieldTypeKind = structFieldTypeLayout->getKind(); + + SLANG_CHECK_MSG( + structFieldTypeKind == slang::TypeReflection::Kind::Array, + "Field 'xs' is not an array type"); + + // Check the array size + auto arraySize = structFieldTypeLayout->getElementCount(); + // 0 becuase we haven't resolved the constant + SLANG_CHECK_MSG( + arraySize == 0, + "Field 'xs' array size does not match expected size"); + + // 4 because we're resolving it + const auto resolvedArraySize = + structFieldTypeLayout->getElementCount(slangReflection); + SLANG_CHECK_MSG( + resolvedArraySize == expectedSize, + "Field 'xs' array size does not match expected size"); + + break; + } + } + + SLANG_CHECK_MSG(foundXsField, "Could not find field 'xs' in struct S"); + break; + } + } + + SLANG_CHECK_MSG(foundBuffer, "Could not find buffer 'b' in global scope"); +} + + +void linkTimeConstantArraySizeTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create transient heap + Slang::ComPtr<ITransientResourceHeap> transientHeap; + ITransientResourceHeap::Desc transientHeapDesc = {}; + transientHeapDesc.constantBufferSize = 4096; + GFX_CHECK_CALL_ABORT( + device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef())); + + // Load and link program + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram( + device, + shaderProgram, + "link-time-constant-array-size-main", + "link-time-constant-array-size-lib", + "computeMain", + slangReflection)); + + // Check array size through reflection + const int N = 4; // This should match the constant in lib.slang + + validateArraySizeInStruct(context, slangReflection, N); + + // Create compute pipeline + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<gfx::IPipelineState> pipelineState; + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + + // Create buffer for struct S with array of size N + int32_t initialData[] = {1, 2, 3, 4}; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.sizeInBytes = N * sizeof(int32_t); + bufferDesc.format = gfx::Format::Unknown; + bufferDesc.elementSize = sizeof(int32_t); + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::ShaderResource, + ResourceState::UnorderedAccess, + ResourceState::CopyDestination, + ResourceState::CopySource); + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBufferResource> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBufferResource(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + ComPtr<IResourceView> bufferView; + IResourceView::Desc viewDesc = {}; + viewDesc.type = IResourceView::Type::UnorderedAccess; + viewDesc.format = Format::Unknown; + GFX_CHECK_CALL_ABORT( + device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef())); + + // Record and execute command buffer + { + ICommandQueue::Desc queueDesc = {ICommandQueue::QueueType::Graphics}; + auto queue = device->createCommandQueue(queueDesc); + + auto commandBuffer = transientHeap->createCommandBuffer(); + auto encoder = commandBuffer->encodeComputeCommands(); + + auto rootObject = encoder->bindPipeline(pipelineState); + + ShaderCursor rootCursor(rootObject); + rootCursor.getPath("b").setResource(bufferView); + + encoder->dispatchCompute(1, 1, 1); + encoder->endEncoding(); + commandBuffer->close(); + queue->executeCommandBuffer(commandBuffer); + queue->waitOnHost(); + } + + // Expected results: each element is input * N + // With N=4 and inputs [1,2,3,4], expected output is [4,8,12,16] + compareComputeResult(device, numbersBuffer, Slang::makeArray<int>(4, 8, 12, 16)); +} + +SLANG_UNIT_TEST(linkTimeConstantArraySizeD3D12) +{ + runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); +} + +SLANG_UNIT_TEST(linkTimeConstantArraySizeVulkan) +{ + runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); +} + +} // namespace gfx_test |
