summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-specialize.cpp49
-rw-r--r--source/slang/slang-type-layout.cpp13
-rw-r--r--tests/compute/interface-func-param-in-struct.slang43
-rw-r--r--tests/compute/interface-func-param-in-struct.slang.expected.txt4
4 files changed, 102 insertions, 7 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index d364de3fb..ceba4b03c 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -808,6 +808,24 @@ struct SpecializationContext
return false;
}
+ /// Used by `maybeSpecailizeBufferLoadCall`, this function returns a new specialized callee that
+ /// replaces a `specialize(.operator[], oldType)` to `specialize(.operator[], newElementType)`.
+ IRInst* getNewSpecializedBufferLoadCallee(
+ IRInst* oldSpecializedCallee,
+ IRType* newContainerType,
+ IRType* newElementType)
+ {
+ auto oldSpecialize = cast<IRSpecialize>(oldSpecializedCallee);
+ SLANG_ASSERT(oldSpecialize->getArgCount() == 1);
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(oldSpecializedCallee);
+ auto calleeType = builder.getFuncType(1, &newContainerType, newElementType);
+ auto newSpecialize = builder.emitSpecializeInst(
+ calleeType, oldSpecialize->getBase(), 1, (IRInst**)&newElementType);
+ return newSpecialize;
+ }
+
/// Transform a buffer load intrinsic call.
/// `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into
/// `wrapExistential(bufferLoad(bufferObj, loadArgs), wragArgs)`.
@@ -844,11 +862,18 @@ struct SpecializationContext
{
slotOperands.add(wrapExistential->getSlotOperand(ii));
}
- auto newCall = builder.emitCallInst(elementType, inst->getCallee(), args);
+ // The old callee should be in the form of `specialize(.operator[], IInterfaceType)`,
+ // we should update it to be `specialize(.operator[], elementType)`, so the return type
+ // of the load call is `elementType`.
+ auto oldCallee = inst->getCallee();
+ auto newCallee = getNewSpecializedBufferLoadCallee(inst->getCallee(), sbType, elementType);
+ auto newCall = builder.emitCallInst(elementType, newCallee, args);
auto newWrapExistential = builder.emitWrapExistential(
resultType, newCall, slotOperandCount, slotOperands.getBuffer());
inst->replaceUsesWith(newWrapExistential);
inst->removeAndDeallocate();
+ SLANG_ASSERT(!oldCallee->hasUses());
+ oldCallee->removeAndDeallocate();
addUsersToWorkList(newWrapExistential);
return true;
}
@@ -1080,7 +1105,8 @@ struct SpecializationContext
//
if(as<IRInterfaceType>(type))
return true;
-
+ if (calcExistentialTypeParamSlotCount(type) != 0)
+ return true;
// Eventually we will also want to handle arrays over
// existential types, but that will require careful
// handling in many places.
@@ -1518,6 +1544,11 @@ struct SpecializationContext
type = arrayType->getElementType();
goto top;
}
+ else if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(type))
+ {
+ type = sbType->getElementType();
+ goto top;
+ }
else if( auto structType = as<IRStructType>(type) )
{
UInt count = 0;
@@ -1800,6 +1831,11 @@ struct SpecializationContext
type = ptrLikeType->getElementType();
goto top;
}
+ else if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(type))
+ {
+ type = sbType->getElementType();
+ goto top;
+ }
else if( auto structType = as<IRStructType>(type) )
{
UInt count = 0;
@@ -1872,15 +1908,15 @@ struct SpecializationContext
baseElementType,
slotOperandCount,
type->getExistentialArgs());
- addToWorkList(wrappedElementType);
auto newPtrLikeType = builder.getType(
baseType->op,
1,
&wrappedElementType);
+ addUsersToWorkList(type);
addToWorkList(newPtrLikeType);
+ addToWorkList(wrappedElementType);
- addUsersToWorkList(type);
type->replaceUsesWith(newPtrLikeType);
type->removeAndDeallocate();
return;
@@ -1911,10 +1947,13 @@ struct SpecializationContext
}
IRStructType* newStructType = nullptr;
+ addUsersToWorkList(type);
+
if( !existentialSpecializedStructs.TryGetValue(key, newStructType) )
{
builder.setInsertBefore(baseStructType);
newStructType = builder.createStructType();
+ addToWorkList(newStructType);
auto fieldSlotArgs = type->getExistentialArgs();
@@ -1939,10 +1978,8 @@ struct SpecializationContext
}
existentialSpecializedStructs.Add(key, newStructType);
- addToWorkList(newStructType);
}
- addUsersToWorkList(type);
type->replaceUsesWith(newStructType);
type->removeAndDeallocate();
return;
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 88ec5f4f7..33bdb4ef4 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -2536,8 +2536,19 @@ createStructuredBufferTypeLayout(
typeLayout->addResourceUsage(info.kind, info.size);
}
+ // If element type contains existential type params and object params,
+ // we need to propagate them through the StructuredBufferLayout.
+ if (auto existentialTypeInfo = elementTypeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam))
+ {
+ typeLayout->addResourceUsage(existentialTypeInfo->kind, existentialTypeInfo->count);
+ }
+ if (auto existentialObjInfo = elementTypeLayout->FindResourceInfo(LayoutResourceKind::ExistentialObjectParam))
+ {
+ typeLayout->addResourceUsage(existentialObjInfo->kind, existentialObjInfo->count);
+ }
+
// Note: for now we don't deal with the case of a structured
- // buffer that might contain anything other than "uniform" data,
+ // buffer that might contain any other resource types,
// because there really isn't a way to implement that.
return typeLayout;
diff --git a/tests/compute/interface-func-param-in-struct.slang b/tests/compute/interface-func-param-in-struct.slang
new file mode 100644
index 000000000..550758f38
--- /dev/null
+++ b/tests/compute/interface-func-param-in-struct.slang
@@ -0,0 +1,43 @@
+// Tests specializing a function with existential-struct-typed param.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cuda
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cpu
+
+[anyValueSize(8)]
+interface IInterface
+{
+ uint eval();
+}
+
+struct Impl : IInterface
+{
+ uint val;
+ uint eval()
+ {
+ return val;
+ }
+};
+
+struct Params
+{
+ StructuredBuffer<IInterface> obj;
+};
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
+RWStructuredBuffer<uint> gOutputBuffer;
+
+void compute(uint tid, Params p)
+{
+ gOutputBuffer[tid] = p.obj[0].eval();
+}
+
+//TEST_INPUT: entryPointExistentialType Impl
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID,
+//TEST_INPUT:ubuffer(data=[0 0 0 0 1 0], stride=4):name=params.obj
+ uniform Params params)
+{
+ uint tid = dispatchThreadID.x;
+ compute(tid, params);
+} \ No newline at end of file
diff --git a/tests/compute/interface-func-param-in-struct.slang.expected.txt b/tests/compute/interface-func-param-in-struct.slang.expected.txt
new file mode 100644
index 000000000..98fb6a686
--- /dev/null
+++ b/tests/compute/interface-func-param-in-struct.slang.expected.txt
@@ -0,0 +1,4 @@
+1
+1
+1
+1