diff options
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 94 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.cpp | 15 | ||||
| -rw-r--r-- | tests/compute/array-existential-parameter.slang | 48 | ||||
| -rw-r--r-- | tests/compute/array-existential-parameter.slang.expected.txt | 4 | ||||
| -rw-r--r-- | tools/render-test/bind-location.cpp | 25 |
6 files changed, 192 insertions, 8 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 18641d8b7..3b390015b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1244,6 +1244,20 @@ struct IRFieldAddress : IRInst }; +struct IRGetElement : IRInst +{ + IR_LEAF_ISA(getElement); + IRInst* getBase() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } +}; + +struct IRGetElementPtr : IRInst +{ + IR_LEAF_ISA(getElementPtr); + IRInst* getBase() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } +}; + struct IRGetAddress : IRInst { IR_LEAF_ISA(getAddr); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 7cbba9737..d364de3fb 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -530,6 +530,13 @@ struct SpecializationContext maybeSpecializeFieldAddress(as<IRFieldAddress>(inst)); break; + case kIROp_getElement: + maybeSpecializeGetElement(as<IRGetElement>(inst)); + break; + case kIROp_getElementPtr: + maybeSpecializeGetElementAddress(as<IRGetElementPtr>(inst)); + break; + case kIROp_BindExistentialsType: maybeSpecializeBindExistentialsType(as<IRBindExistentialsType>(inst)); break; @@ -1506,6 +1513,11 @@ struct SpecializationContext type = ptrLikeType->getElementType(); goto top; } + else if (auto arrayType = as<IRArrayTypeBase>(type)) + { + type = arrayType->getElementType(); + goto top; + } else if( auto structType = as<IRStructType>(type) ) { UInt count = 0; @@ -1695,6 +1707,82 @@ struct SpecializationContext } } + void maybeSpecializeGetElement(IRGetElement* inst) + { + auto baseArg = inst->getBase(); + if (auto wrapInst = as<IRWrapExistential>(baseArg)) + { + // We have `getElement(wrapExistential(val, ...), index)` + // We need to replace this instruction with + // `wrapExistential(getElement(val, index), ...)` + auto index = inst->getIndex(); + + auto val = wrapInst->getWrappedValue(); + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType(); + + List<IRInst*> slotOperands; + UInt slotOperandCount = wrapInst->getSlotOperandCount(); + + for (UInt ii = 0; ii < slotOperandCount; ++ii) + { + slotOperands.add(wrapInst->getSlotOperand(ii)); + } + + auto newGetElement = builder.emitElementExtract(elementType, val, index); + + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, newGetElement, slotOperandCount, slotOperands.getBuffer()); + + addUsersToWorkList(inst); + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + + void maybeSpecializeGetElementAddress(IRGetElementPtr* inst) + { + auto baseArg = inst->getBase(); + if (auto wrapInst = as<IRWrapExistential>(baseArg)) + { + // We have `getElementPtr(wrapExistential(val, ...), index)` + // We need to replace this instruction with + // `wrapExistential(getElementPtr(val, index), ...)` + auto index = inst->getIndex(); + + auto val = wrapInst->getWrappedValue(); + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType(); + + List<IRInst*> slotOperands; + UInt slotOperandCount = wrapInst->getSlotOperandCount(); + + for (UInt ii = 0; ii < slotOperandCount; ++ii) + { + slotOperands.add(wrapInst->getSlotOperand(ii)); + } + + auto newElementAddr = builder.emitElementAddress(elementType, val, index); + + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, newElementAddr, slotOperandCount, slotOperands.getBuffer()); + + addUsersToWorkList(inst); + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + UInt calcExistentialTypeParamSlotCount(IRType* type) { top: @@ -1764,7 +1852,9 @@ struct SpecializationContext type->removeAndDeallocate(); return; } - else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) ) + else if( as<IRPointerLikeType>(baseType) || + as<IRHLSLStructuredBufferTypeBase>(baseType) || + as<IRArrayTypeBase>(baseType)) { // A `BindExistentials<P<T>, ...>` can be simplified to // `P<BindExistentials<T, ...>>` when `P` is a pointer-like @@ -1773,6 +1863,8 @@ struct SpecializationContext IRType* baseElementType = nullptr; if (auto basePtrLikeType = as<IRPointerLikeType>(baseType)) baseElementType = basePtrLikeType->getElementType(); + else if (auto arrayType = as<IRArrayTypeBase>(baseType)) + baseElementType = arrayType->getElementType(); else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType)) baseElementType = baseSBType->getElementType(); diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 46a521f79..88ec5f4f7 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -3489,6 +3489,13 @@ static TypeLayoutResult _createTypeLayout( { arrayResourceCount = elementResourceInfo.count; } + // The second exception to this is arrays of an existential type + // where the entire array should be specialized to a single concrete type. + // + else if (elementResourceInfo.kind == LayoutResourceKind::ExistentialTypeParam) + { + arrayResourceCount = elementResourceInfo.count; + } // // The next big exception is when we are forming an unbounded-size // array and the element type got "adjusted," because that means @@ -3677,6 +3684,7 @@ static TypeLayoutResult _createTypeLayout( typeLayout->rules = rules; LayoutSize fixedExistentialValueSize = 0; + LayoutSize uniformSlotSize = 0; bool targetSupportsPointer = isCPUTarget(context.targetReq) || isCUDATarget(context.targetReq); @@ -3689,7 +3697,7 @@ static TypeLayoutResult _createTypeLayout( fixedExistentialValueSize = anyValueAttr->size; } // Append 16 bytes to accommodate RTTI pointer and witness table pointer. - auto uniformSlotSize = fixedExistentialValueSize + 16; + uniformSlotSize = fixedExistentialValueSize + 16; typeLayout->addResourceUsage(LayoutResourceKind::Uniform, uniformSlotSize); } typeLayout->addResourceUsage(LayoutResourceKind::ExistentialTypeParam, 1); @@ -3736,8 +3744,9 @@ static TypeLayoutResult _createTypeLayout( typeLayout->pendingDataTypeLayout = concreteTypeLayout; } } - - return TypeLayoutResult(typeLayout, SimpleLayoutInfo()); + // Interface type occupies a uniform slot for the fixed size storage, with alignment of 4 bytes. + return TypeLayoutResult( + typeLayout, SimpleLayoutInfo(LayoutResourceKind::Uniform, uniformSlotSize, 4)); } else if( auto enumDeclRef = declRef.as<EnumDecl>() ) { diff --git a/tests/compute/array-existential-parameter.slang b/tests/compute/array-existential-parameter.slang new file mode 100644 index 000000000..f640ba752 --- /dev/null +++ b/tests/compute/array-existential-parameter.slang @@ -0,0 +1,48 @@ +// Test using existential shader parameter that is an interface array. + +//TEST(compute):COMPARE_COMPUTE:-cpu +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda + +[anyValueSize(8)] +interface IInterface +{ + int run(int input); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer +RWStructuredBuffer<int> gOutputBuffer; + +struct Params +{ + IInterface values[2]; +}; + +//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb +ConstantBuffer<Params> gCb; + +//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb2 +ConstantBuffer<Params> gCb2; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + let tid = dispatchThreadID.x; + + let inputVal : int = tid; + let outputVal = gCb.values[0].run(inputVal) + gCb2.values[0].run(inputVal); + + gOutputBuffer[tid] = outputVal; +} + +//TEST_INPUT: globalExistentialType MyImpl +//TEST_INPUT: globalExistentialType __Dynamic + +// Type must be marked `public` to ensure it is visible in the generated DLL. +public struct MyImpl : IInterface +{ + int val; + int run(int input) + { + return input + val; + } +}; diff --git a/tests/compute/array-existential-parameter.slang.expected.txt b/tests/compute/array-existential-parameter.slang.expected.txt new file mode 100644 index 000000000..628d2a8fd --- /dev/null +++ b/tests/compute/array-existential-parameter.slang.expected.txt @@ -0,0 +1,4 @@ +2 +4 +6 +8 diff --git a/tools/render-test/bind-location.cpp b/tools/render-test/bind-location.cpp index 4ec590fd5..f791e56f6 100644 --- a/tools/render-test/bind-location.cpp +++ b/tools/render-test/bind-location.cpp @@ -328,6 +328,23 @@ void BindSet::calcValueLocations(const BindLocation& location, Slang::List<BindL } } +// Finds the first category from layout reflection that represents an actual value +// i.e. that is not ExistentialType or ExistentialObject. +template<typename LayoutReflectionType> +slang::ParameterCategory getFirstNonExistentialValueCategory(LayoutReflectionType* layout) +{ + slang::ParameterCategory category = slang::ParameterCategory::None; + for (UInt i = 0; i < layout->getCategoryCount(); i++) + { + auto currentCategory = layout->getCategoryByIndex((unsigned int)i); + if (currentCategory == slang::ParameterCategory::ExistentialTypeParam || + currentCategory == slang::ParameterCategory::ExistentialObjectParam) + continue; + category = currentCategory; + } + return category; +} + BindLocation BindSet::toField(const BindLocation& loc, slang::VariableLayoutReflection* field) const { const Index categoryCount = Index(field->getCategoryCount()); @@ -363,8 +380,8 @@ BindLocation BindSet::toField(const BindLocation& loc, slang::VariableLayoutRefl } else { - SLANG_ASSERT(categoryCount == 1); - auto category = field->getCategoryByIndex(0); + slang::ParameterCategory category = getFirstNonExistentialValueCategory(field); + SLANG_ASSERT(category != slang::ParameterCategory::None); // If I'm going from mixed, then I will have multiple items being tracked (so won't be here) // If I'm not, then I'm getting an inplace field. It must be relative @@ -496,8 +513,8 @@ BindLocation BindSet::toIndex(const BindLocation& loc, Index index) const } else { - SLANG_ASSERT(categoryCount == 1); - auto category = elementTypeLayout->getCategoryByIndex(0); + slang::ParameterCategory category = getFirstNonExistentialValueCategory(elementTypeLayout); + SLANG_ASSERT(category != slang::ParameterCategory::None); const auto elementStride = typeLayout->getElementStride(category); |
