summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-specialize.cpp94
-rw-r--r--source/slang/slang-type-layout.cpp15
-rw-r--r--tests/compute/array-existential-parameter.slang48
-rw-r--r--tests/compute/array-existential-parameter.slang.expected.txt4
-rw-r--r--tools/render-test/bind-location.cpp25
6 files changed, 192 insertions, 8 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 18641d8b7..3b390015b 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1244,6 +1244,20 @@ struct IRFieldAddress : IRInst
};
+struct IRGetElement : IRInst
+{
+ IR_LEAF_ISA(getElement);
+ IRInst* getBase() { return getOperand(0); }
+ IRInst* getIndex() { return getOperand(1); }
+};
+
+struct IRGetElementPtr : IRInst
+{
+ IR_LEAF_ISA(getElementPtr);
+ IRInst* getBase() { return getOperand(0); }
+ IRInst* getIndex() { return getOperand(1); }
+};
+
struct IRGetAddress : IRInst
{
IR_LEAF_ISA(getAddr);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 7cbba9737..d364de3fb 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -530,6 +530,13 @@ struct SpecializationContext
maybeSpecializeFieldAddress(as<IRFieldAddress>(inst));
break;
+ case kIROp_getElement:
+ maybeSpecializeGetElement(as<IRGetElement>(inst));
+ break;
+ case kIROp_getElementPtr:
+ maybeSpecializeGetElementAddress(as<IRGetElementPtr>(inst));
+ break;
+
case kIROp_BindExistentialsType:
maybeSpecializeBindExistentialsType(as<IRBindExistentialsType>(inst));
break;
@@ -1506,6 +1513,11 @@ struct SpecializationContext
type = ptrLikeType->getElementType();
goto top;
}
+ else if (auto arrayType = as<IRArrayTypeBase>(type))
+ {
+ type = arrayType->getElementType();
+ goto top;
+ }
else if( auto structType = as<IRStructType>(type) )
{
UInt count = 0;
@@ -1695,6 +1707,82 @@ struct SpecializationContext
}
}
+ void maybeSpecializeGetElement(IRGetElement* inst)
+ {
+ auto baseArg = inst->getBase();
+ if (auto wrapInst = as<IRWrapExistential>(baseArg))
+ {
+ // We have `getElement(wrapExistential(val, ...), index)`
+ // We need to replace this instruction with
+ // `wrapExistential(getElement(val, index), ...)`
+ auto index = inst->getIndex();
+
+ auto val = wrapInst->getWrappedValue();
+ auto resultType = inst->getFullType();
+
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(inst);
+
+ auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType();
+
+ List<IRInst*> slotOperands;
+ UInt slotOperandCount = wrapInst->getSlotOperandCount();
+
+ for (UInt ii = 0; ii < slotOperandCount; ++ii)
+ {
+ slotOperands.add(wrapInst->getSlotOperand(ii));
+ }
+
+ auto newGetElement = builder.emitElementExtract(elementType, val, index);
+
+ auto newWrapExistentialInst = builder.emitWrapExistential(
+ resultType, newGetElement, slotOperandCount, slotOperands.getBuffer());
+
+ addUsersToWorkList(inst);
+ inst->replaceUsesWith(newWrapExistentialInst);
+ inst->removeAndDeallocate();
+ }
+ }
+
+ void maybeSpecializeGetElementAddress(IRGetElementPtr* inst)
+ {
+ auto baseArg = inst->getBase();
+ if (auto wrapInst = as<IRWrapExistential>(baseArg))
+ {
+ // We have `getElementPtr(wrapExistential(val, ...), index)`
+ // We need to replace this instruction with
+ // `wrapExistential(getElementPtr(val, index), ...)`
+ auto index = inst->getIndex();
+
+ auto val = wrapInst->getWrappedValue();
+ auto resultType = inst->getFullType();
+
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(inst);
+
+ auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType();
+
+ List<IRInst*> slotOperands;
+ UInt slotOperandCount = wrapInst->getSlotOperandCount();
+
+ for (UInt ii = 0; ii < slotOperandCount; ++ii)
+ {
+ slotOperands.add(wrapInst->getSlotOperand(ii));
+ }
+
+ auto newElementAddr = builder.emitElementAddress(elementType, val, index);
+
+ auto newWrapExistentialInst = builder.emitWrapExistential(
+ resultType, newElementAddr, slotOperandCount, slotOperands.getBuffer());
+
+ addUsersToWorkList(inst);
+ inst->replaceUsesWith(newWrapExistentialInst);
+ inst->removeAndDeallocate();
+ }
+ }
+
UInt calcExistentialTypeParamSlotCount(IRType* type)
{
top:
@@ -1764,7 +1852,9 @@ struct SpecializationContext
type->removeAndDeallocate();
return;
}
- else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) )
+ else if( as<IRPointerLikeType>(baseType) ||
+ as<IRHLSLStructuredBufferTypeBase>(baseType) ||
+ as<IRArrayTypeBase>(baseType))
{
// A `BindExistentials<P<T>, ...>` can be simplified to
// `P<BindExistentials<T, ...>>` when `P` is a pointer-like
@@ -1773,6 +1863,8 @@ struct SpecializationContext
IRType* baseElementType = nullptr;
if (auto basePtrLikeType = as<IRPointerLikeType>(baseType))
baseElementType = basePtrLikeType->getElementType();
+ else if (auto arrayType = as<IRArrayTypeBase>(baseType))
+ baseElementType = arrayType->getElementType();
else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType))
baseElementType = baseSBType->getElementType();
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 46a521f79..88ec5f4f7 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -3489,6 +3489,13 @@ static TypeLayoutResult _createTypeLayout(
{
arrayResourceCount = elementResourceInfo.count;
}
+ // The second exception to this is arrays of an existential type
+ // where the entire array should be specialized to a single concrete type.
+ //
+ else if (elementResourceInfo.kind == LayoutResourceKind::ExistentialTypeParam)
+ {
+ arrayResourceCount = elementResourceInfo.count;
+ }
//
// The next big exception is when we are forming an unbounded-size
// array and the element type got "adjusted," because that means
@@ -3677,6 +3684,7 @@ static TypeLayoutResult _createTypeLayout(
typeLayout->rules = rules;
LayoutSize fixedExistentialValueSize = 0;
+ LayoutSize uniformSlotSize = 0;
bool targetSupportsPointer =
isCPUTarget(context.targetReq) || isCUDATarget(context.targetReq);
@@ -3689,7 +3697,7 @@ static TypeLayoutResult _createTypeLayout(
fixedExistentialValueSize = anyValueAttr->size;
}
// Append 16 bytes to accommodate RTTI pointer and witness table pointer.
- auto uniformSlotSize = fixedExistentialValueSize + 16;
+ uniformSlotSize = fixedExistentialValueSize + 16;
typeLayout->addResourceUsage(LayoutResourceKind::Uniform, uniformSlotSize);
}
typeLayout->addResourceUsage(LayoutResourceKind::ExistentialTypeParam, 1);
@@ -3736,8 +3744,9 @@ static TypeLayoutResult _createTypeLayout(
typeLayout->pendingDataTypeLayout = concreteTypeLayout;
}
}
-
- return TypeLayoutResult(typeLayout, SimpleLayoutInfo());
+ // Interface type occupies a uniform slot for the fixed size storage, with alignment of 4 bytes.
+ return TypeLayoutResult(
+ typeLayout, SimpleLayoutInfo(LayoutResourceKind::Uniform, uniformSlotSize, 4));
}
else if( auto enumDeclRef = declRef.as<EnumDecl>() )
{
diff --git a/tests/compute/array-existential-parameter.slang b/tests/compute/array-existential-parameter.slang
new file mode 100644
index 000000000..f640ba752
--- /dev/null
+++ b/tests/compute/array-existential-parameter.slang
@@ -0,0 +1,48 @@
+// Test using existential shader parameter that is an interface array.
+
+//TEST(compute):COMPARE_COMPUTE:-cpu
+//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda
+
+[anyValueSize(8)]
+interface IInterface
+{
+ int run(int input);
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
+RWStructuredBuffer<int> gOutputBuffer;
+
+struct Params
+{
+ IInterface values[2];
+};
+
+//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
+ConstantBuffer<Params> gCb;
+
+//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb2
+ConstantBuffer<Params> gCb2;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ let tid = dispatchThreadID.x;
+
+ let inputVal : int = tid;
+ let outputVal = gCb.values[0].run(inputVal) + gCb2.values[0].run(inputVal);
+
+ gOutputBuffer[tid] = outputVal;
+}
+
+//TEST_INPUT: globalExistentialType MyImpl
+//TEST_INPUT: globalExistentialType __Dynamic
+
+// Type must be marked `public` to ensure it is visible in the generated DLL.
+public struct MyImpl : IInterface
+{
+ int val;
+ int run(int input)
+ {
+ return input + val;
+ }
+};
diff --git a/tests/compute/array-existential-parameter.slang.expected.txt b/tests/compute/array-existential-parameter.slang.expected.txt
new file mode 100644
index 000000000..628d2a8fd
--- /dev/null
+++ b/tests/compute/array-existential-parameter.slang.expected.txt
@@ -0,0 +1,4 @@
+2
+4
+6
+8
diff --git a/tools/render-test/bind-location.cpp b/tools/render-test/bind-location.cpp
index 4ec590fd5..f791e56f6 100644
--- a/tools/render-test/bind-location.cpp
+++ b/tools/render-test/bind-location.cpp
@@ -328,6 +328,23 @@ void BindSet::calcValueLocations(const BindLocation& location, Slang::List<BindL
}
}
+// Finds the first category from layout reflection that represents an actual value
+// i.e. that is not ExistentialType or ExistentialObject.
+template<typename LayoutReflectionType>
+slang::ParameterCategory getFirstNonExistentialValueCategory(LayoutReflectionType* layout)
+{
+ slang::ParameterCategory category = slang::ParameterCategory::None;
+ for (UInt i = 0; i < layout->getCategoryCount(); i++)
+ {
+ auto currentCategory = layout->getCategoryByIndex((unsigned int)i);
+ if (currentCategory == slang::ParameterCategory::ExistentialTypeParam ||
+ currentCategory == slang::ParameterCategory::ExistentialObjectParam)
+ continue;
+ category = currentCategory;
+ }
+ return category;
+}
+
BindLocation BindSet::toField(const BindLocation& loc, slang::VariableLayoutReflection* field) const
{
const Index categoryCount = Index(field->getCategoryCount());
@@ -363,8 +380,8 @@ BindLocation BindSet::toField(const BindLocation& loc, slang::VariableLayoutRefl
}
else
{
- SLANG_ASSERT(categoryCount == 1);
- auto category = field->getCategoryByIndex(0);
+ slang::ParameterCategory category = getFirstNonExistentialValueCategory(field);
+ SLANG_ASSERT(category != slang::ParameterCategory::None);
// If I'm going from mixed, then I will have multiple items being tracked (so won't be here)
// If I'm not, then I'm getting an inplace field. It must be relative
@@ -496,8 +513,8 @@ BindLocation BindSet::toIndex(const BindLocation& loc, Index index) const
}
else
{
- SLANG_ASSERT(categoryCount == 1);
- auto category = elementTypeLayout->getCategoryByIndex(0);
+ slang::ParameterCategory category = getFirstNonExistentialValueCategory(elementTypeLayout);
+ SLANG_ASSERT(category != slang::ParameterCategory::None);
const auto elementStride = typeLayout->getElementStride(category);