diff options
| -rw-r--r-- | include/slang-deprecated.h | 10 | ||||
| -rw-r--r-- | include/slang.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 36 | ||||
| -rw-r--r-- | tests/bugs/link-time-constant-array-size-lib.slang | 1 | ||||
| -rw-r--r-- | tests/bugs/link-time-constant-array-size-main.slang | 22 | ||||
| -rw-r--r-- | tests/library/ambiguous-extern-export-entry.slang.expected | 17 | ||||
| -rw-r--r-- | tools/gfx-unit-test/link-time-constant-array-size-lib.slang | 1 | ||||
| -rw-r--r-- | tools/gfx-unit-test/link-time-constant-array-size-main.slang | 12 | ||||
| -rw-r--r-- | tools/gfx-unit-test/link-time-constant-array-size.cpp | 269 |
11 files changed, 385 insertions, 7 deletions
diff --git a/include/slang-deprecated.h b/include/slang-deprecated.h index 2ae91c6d8..f210e8c48 100644 --- a/include/slang-deprecated.h +++ b/include/slang-deprecated.h @@ -478,9 +478,19 @@ extern "C" If the size of a type cannot be statically computed, perhaps because it depends on a generic parameter that has not been bound to a specific value, this function returns zero. + + Use spReflectionType_GetSpecializedElementCount if the size is dependent on + a link time constant */ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* type); + /** The same as spReflectionType_GetElementCount except it takes into account specialization + * information from the given reflection info + */ + SLANG_API size_t spReflectionType_GetSpecializedElementCount( + SlangReflectionType* type, + SlangReflection* reflection); + SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionType* type); SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* type); diff --git a/include/slang.h b/include/slang.h index 782c4a082..4a3af7de1 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2291,9 +2291,9 @@ struct TypeReflection } // only useful if `getKind() == Kind::Array` - size_t getElementCount() + size_t getElementCount(SlangReflection* reflection = nullptr) { - return spReflectionType_GetElementCount((SlangReflectionType*)this); + return spReflectionType_GetSpecializedElementCount((SlangReflectionType*)this, reflection); } size_t getTotalArrayElementCount() @@ -2454,6 +2454,8 @@ enum class BindingType : SlangBindingTypeIntegral ExtMask = SLANG_BINDING_TYPE_EXT_MASK, }; +struct ShaderReflection; + struct TypeLayoutReflection { TypeReflection* getType() @@ -2543,7 +2545,10 @@ struct TypeLayoutReflection } // only useful if `getKind() == Kind::Array` - size_t getElementCount() { return getType()->getElementCount(); } + size_t getElementCount(ShaderReflection* reflection = nullptr) + { + return getType()->getElementCount((SlangReflection*)reflection); + } size_t getTotalArrayElementCount() { return getType()->getTotalArrayElementCount(); } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0dd859bb2..9e4d5a6d3 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10468,6 +10468,11 @@ void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl) getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); return; } + + if (elementCount->isLinkTimeVal()) + { + getSink()->diagnose(varDecl, Diagnostics::linkTimeConstantArraySize); + } } bool getExtensionTargetDeclList( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 384a81f9b..8efdf1d91 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1587,6 +1587,14 @@ DIAGNOSTIC( switchDuplicateCases, "duplicate cases not allowed within a 'switch' statement") +// 310xx: link time specializaion +DIAGNOSTIC( + 31000, + Warning, + linkTimeConstantArraySize, + "Link-time constant sized arrays are a work in progress feature, some aspects of the " + "reflection API may not work") + // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") DIAGNOSTIC( diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index b0d88e954..56a82e17e 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4,6 +4,7 @@ #include "slang-check-impl.h" #include "slang-check.h" #include "slang-compiler.h" +#include "slang-deprecated.h" #include "slang-syntax.h" #include "slang-type-layout.h" #include "slang.h" @@ -568,20 +569,45 @@ SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex( SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) { + return spReflectionType_GetSpecializedElementCount(inType, nullptr); +} + +SLANG_API size_t spReflectionType_GetSpecializedElementCount( + SlangReflectionType* inType, + SlangReflection* reflection) +{ auto type = convert(inType); if (!type) return 0; + IntVal* elementCount; + bool isUnsized; if (auto arrayType = as<ArrayExpressionType>(type)) { - return !arrayType->isUnsized() ? (size_t)getIntVal(arrayType->getElementCount()) : 0; + elementCount = arrayType->getElementCount(); + isUnsized = arrayType->isUnsized(); } else if (auto vectorType = as<VectorExpressionType>(type)) { - return (size_t)getIntVal(vectorType->getElementCount()); + elementCount = vectorType->getElementCount(); + isUnsized = false; + } + else + { + return 0; } - return 0; + if (const auto program = convert(reflection)) + { + if (const auto componentType = program->getProgram()) + { + if (const auto c = componentType->tryFoldIntVal(elementCount)) + return c->getValue(); + } + } + + const auto isWithoutSize = isUnsized || elementCount->isLinkTimeVal(); + return isWithoutSize ? 0 : (size_t)getIntVal(elementCount); } SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionType* inType) @@ -1945,7 +1971,9 @@ struct ExtendedTypeLayoutContext LayoutSize elementCount = LayoutSize::infinite(); if (auto arrayType = as<ArrayExpressionType>(arrayTypeLayout->type)) { - if (!arrayType->isUnsized()) + const auto isWithoutSize = + arrayType->isUnsized() || arrayType->getElementCount()->isLinkTimeVal(); + if (!isWithoutSize) { elementCount = LayoutSize::RawValue(getIntVal(arrayType->getElementCount())); } diff --git a/tests/bugs/link-time-constant-array-size-lib.slang b/tests/bugs/link-time-constant-array-size-lib.slang new file mode 100644 index 000000000..cc83c30fa --- /dev/null +++ b/tests/bugs/link-time-constant-array-size-lib.slang @@ -0,0 +1 @@ +export static const int N = 1597463007; diff --git a/tests/bugs/link-time-constant-array-size-main.slang b/tests/bugs/link-time-constant-array-size-main.slang new file mode 100644 index 000000000..da58decaf --- /dev/null +++ b/tests/bugs/link-time-constant-array-size-main.slang @@ -0,0 +1,22 @@ +//TEST:COMPILE: tests/bugs/link-time-constant-array-size-lib.slang -o tests/bugs/link-time-constant-array-size-lib.slang-module +//TEST:COMPILE: tests/bugs/link-time-constant-array-size-main.slang -o tests/bugs/link-time-constant-array-size-main.slang-module +//TEST:SIMPLE(filecheck=SPIRV): -r tests/bugs/link-time-constant-array-size-main.slang-module -r tests/bugs/link-time-constant-array-size-lib.slang-module -target spirv -o out.spv -stage compute -entry computeMain + +extern static const int N; + +// SPIRV: ([[# @LINE+1]]): warning 31000 +struct S { int xs[N]; } + +RWStructuredBuffer<S> b; + +ParameterBlock<S> p; + +[numthreads(1, 1, 1)] +void computeMain() +{ + // check that we multiply by our special number + // SPIRV: [[fisqr:%[a-zA-Z0-9_]+]] = OpConstant %int 1597463007 + // SPIRV: {{%[0-9]+}} = OpIMul %int {{%[0-9]+}} [[fisqr]] + for(int i = 0; i < N; ++i) + b[0].xs[i] = p.xs[i] * N; +} diff --git a/tests/library/ambiguous-extern-export-entry.slang.expected b/tests/library/ambiguous-extern-export-entry.slang.expected new file mode 100644 index 000000000..4ee9ac6c5 --- /dev/null +++ b/tests/library/ambiguous-extern-export-entry.slang.expected @@ -0,0 +1,17 @@ +result code = 0 +standard error = { +tests/library/ambiguous-extern-export-lib1.slang(4): warning 31000: Link-time constant sized arrays are a work in progress feature, some aspects of the reflection API may not work +public extern static const int[call_data_len] call_group_vector; + ^~~~~~~~~~~~~~~~~ +tests/library/ambiguous-extern-export-lib1.slang(5): warning 31000: Link-time constant sized arrays are a work in progress feature, some aspects of the reflection API may not work +public static int[call_data_len] call_id_1 = {}; + ^~~~~~~~~ +tests/library/ambiguous-extern-export-lib2.slang(5): warning 31000: Link-time constant sized arrays are a work in progress feature, some aspects of the reflection API may not work +public extern static const int[call_data_len] call_group_vector; + ^~~~~~~~~~~~~~~~~ +tests/library/ambiguous-extern-export-lib2.slang(7): warning 31000: Link-time constant sized arrays are a work in progress feature, some aspects of the reflection API may not work +public static int[call_data_len] call_id_2 = {}; + ^~~~~~~~~ +} +standard output = { +} diff --git a/tools/gfx-unit-test/link-time-constant-array-size-lib.slang b/tools/gfx-unit-test/link-time-constant-array-size-lib.slang new file mode 100644 index 000000000..ab72e8e2f --- /dev/null +++ b/tools/gfx-unit-test/link-time-constant-array-size-lib.slang @@ -0,0 +1 @@ +export static const int N = 4; diff --git a/tools/gfx-unit-test/link-time-constant-array-size-main.slang b/tools/gfx-unit-test/link-time-constant-array-size-main.slang new file mode 100644 index 000000000..6c5aea928 --- /dev/null +++ b/tools/gfx-unit-test/link-time-constant-array-size-main.slang @@ -0,0 +1,12 @@ +extern static const int N; + +struct S { int xs[N]; } + +RWStructuredBuffer<S> b; + +[numthreads(1, 1, 1)] +void computeMain() +{ + for(int i = 0; i < N; ++i) + b[0].xs[i] *= N; +} diff --git a/tools/gfx-unit-test/link-time-constant-array-size.cpp b/tools/gfx-unit-test/link-time-constant-array-size.cpp new file mode 100644 index 000000000..e67e5636c --- /dev/null +++ b/tools/gfx-unit-test/link-time-constant-array-size.cpp @@ -0,0 +1,269 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "gfx-util/shader-cursor.h" +#include "slang-gfx.h" +#include "unit-test/slang-unit-test.h" + +using namespace gfx; + +namespace gfx_test +{ +static Slang::Result loadProgram( + gfx::IDevice* device, + Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram, + const char* mainModuleName, + const char* libModuleName, + const char* entryPointName, + slang::ProgramLayout*& slangReflection) +{ + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + + // Load main module + slang::IModule* mainModule = + slangSession->loadModule(mainModuleName, diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if (!mainModule) + return SLANG_FAIL; + + // Load library module with constants + slang::IModule* libModule = slangSession->loadModule(libModuleName, diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if (!libModule) + return SLANG_FAIL; + + // Find entry point + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + mainModule->findEntryPointByName(entryPointName, computeEntryPoint.writeRef())); + + // Compose program from modules + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(mainModule); + componentTypes.add(libModule); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + // Link program + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + // Create shader program + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +// Function to validate the array size in struct S +static void validateArraySizeInStruct( + UnitTestContext* context, + slang::ProgramLayout* slangReflection, + int expectedSize) +{ + // Check reflection is available + SLANG_CHECK(slangReflection != nullptr); + + // Get the global scope layout + auto globalScope = slangReflection->getGlobalParamsVarLayout(); + SLANG_CHECK_MSG(globalScope != nullptr, "Could not get global scope layout"); + + auto typeLayout = globalScope->getTypeLayout(); + SLANG_CHECK_MSG(typeLayout != nullptr, "Global scope has no type layout"); + + // Check if the global scope is a struct type + auto kind = typeLayout->getKind(); + SLANG_CHECK_MSG( + kind == slang::TypeReflection::Kind::Struct, + "Global scope is not a struct type"); + + // Find the buffer resource 'b' + bool foundBuffer = false; + auto fieldCount = typeLayout->getFieldCount(); + + for (unsigned int i = 0; i < fieldCount; i++) + { + auto fieldLayout = typeLayout->getFieldByIndex(i); + const char* fieldName = fieldLayout->getName(); + + if (fieldName && strcmp(fieldName, "b") == 0) + { + foundBuffer = true; + + // Get the type layout of the field + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + SLANG_CHECK_MSG(fieldTypeLayout != nullptr, "Field has no type layout"); + + // Get the element type of the structured buffer + auto elementTypeLayout = fieldTypeLayout->getElementTypeLayout(); + SLANG_CHECK_MSG( + elementTypeLayout != nullptr, + "Structured buffer has no element type layout"); + + // Check if it's a struct type + auto elementKind = elementTypeLayout->getKind(); + SLANG_CHECK_MSG( + elementKind == slang::TypeReflection::Kind::Struct, + "Buffer element is not a struct type"); + + // Get the field count of the struct + auto structFieldCount = elementTypeLayout->getFieldCount(); + SLANG_CHECK_MSG(structFieldCount >= 1, "Struct has no fields"); + + // Check for the 'xs' field + bool foundXsField = false; + for (unsigned int j = 0; j < structFieldCount; j++) + { + auto structField = elementTypeLayout->getFieldByIndex(j); + const char* structFieldName = structField->getName(); + + if (structFieldName && strcmp(structFieldName, "xs") == 0) + { + foundXsField = true; + + // Check that it's an array type + auto structFieldTypeLayout = structField->getTypeLayout(); + auto structFieldTypeKind = structFieldTypeLayout->getKind(); + + SLANG_CHECK_MSG( + structFieldTypeKind == slang::TypeReflection::Kind::Array, + "Field 'xs' is not an array type"); + + // Check the array size + auto arraySize = structFieldTypeLayout->getElementCount(); + // 0 becuase we haven't resolved the constant + SLANG_CHECK_MSG( + arraySize == 0, + "Field 'xs' array size does not match expected size"); + + // 4 because we're resolving it + const auto resolvedArraySize = + structFieldTypeLayout->getElementCount(slangReflection); + SLANG_CHECK_MSG( + resolvedArraySize == expectedSize, + "Field 'xs' array size does not match expected size"); + + break; + } + } + + SLANG_CHECK_MSG(foundXsField, "Could not find field 'xs' in struct S"); + break; + } + } + + SLANG_CHECK_MSG(foundBuffer, "Could not find buffer 'b' in global scope"); +} + + +void linkTimeConstantArraySizeTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create transient heap + Slang::ComPtr<ITransientResourceHeap> transientHeap; + ITransientResourceHeap::Desc transientHeapDesc = {}; + transientHeapDesc.constantBufferSize = 4096; + GFX_CHECK_CALL_ABORT( + device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef())); + + // Load and link program + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram( + device, + shaderProgram, + "link-time-constant-array-size-main", + "link-time-constant-array-size-lib", + "computeMain", + slangReflection)); + + // Check array size through reflection + const int N = 4; // This should match the constant in lib.slang + + validateArraySizeInStruct(context, slangReflection, N); + + // Create compute pipeline + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<gfx::IPipelineState> pipelineState; + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + + // Create buffer for struct S with array of size N + int32_t initialData[] = {1, 2, 3, 4}; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.sizeInBytes = N * sizeof(int32_t); + bufferDesc.format = gfx::Format::Unknown; + bufferDesc.elementSize = sizeof(int32_t); + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::ShaderResource, + ResourceState::UnorderedAccess, + ResourceState::CopyDestination, + ResourceState::CopySource); + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBufferResource> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBufferResource(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + ComPtr<IResourceView> bufferView; + IResourceView::Desc viewDesc = {}; + viewDesc.type = IResourceView::Type::UnorderedAccess; + viewDesc.format = Format::Unknown; + GFX_CHECK_CALL_ABORT( + device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef())); + + // Record and execute command buffer + { + ICommandQueue::Desc queueDesc = {ICommandQueue::QueueType::Graphics}; + auto queue = device->createCommandQueue(queueDesc); + + auto commandBuffer = transientHeap->createCommandBuffer(); + auto encoder = commandBuffer->encodeComputeCommands(); + + auto rootObject = encoder->bindPipeline(pipelineState); + + ShaderCursor rootCursor(rootObject); + rootCursor.getPath("b").setResource(bufferView); + + encoder->dispatchCompute(1, 1, 1); + encoder->endEncoding(); + commandBuffer->close(); + queue->executeCommandBuffer(commandBuffer); + queue->waitOnHost(); + } + + // Expected results: each element is input * N + // With N=4 and inputs [1,2,3,4], expected output is [4,8,12,16] + compareComputeResult(device, numbersBuffer, Slang::makeArray<int>(4, 8, 12, 16)); +} + +SLANG_UNIT_TEST(linkTimeConstantArraySizeD3D12) +{ + runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); +} + +SLANG_UNIT_TEST(linkTimeConstantArraySizeVulkan) +{ + runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); +} + +} // namespace gfx_test |
