summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-shader.cpp6
-rw-r--r--source/slang/slang-emit-c-like.cpp2
-rw-r--r--source/slang/slang-emit-cpp.cpp4
-rw-r--r--source/slang/slang-ir-insts.h7
-rw-r--r--source/slang/slang-ir-link.cpp2
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp13
-rw-r--r--source/slang/slang-ir-specialize.cpp97
-rw-r--r--source/slang/slang-ir.cpp3
-rw-r--r--tests/compute/dynamic-dispatch-13.slang45
-rw-r--r--tests/compute/dynamic-dispatch-13.slang.expected.txt4
10 files changed, 163 insertions, 20 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index fe35f8a19..7793dfcbc 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -74,6 +74,12 @@ namespace Slang
loc);
return;
}
+ else if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type))
+ {
+ _collectExistentialSpecializationParamsRec(
+ astBuilder, ioSpecializationParams, structuredBufferType->getElementType(), loc);
+ return;
+ }
if( auto declRefType = as<DeclRefType>(type) )
{
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 090d06f28..8ef18242a 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -222,7 +222,7 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWitnessTable* witnessTable)
{
List<IRWitnessTableEntry*> sortedWitnessTableEntries;
- auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
+ auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
auto witnessTableItems = witnessTable->getChildren();
// Build a dictionary of witness table entries for fast lookup.
Dictionary<IRInst*, IRWitnessTableEntry*> witnessTableEntryDictionary;
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 650b8aa8f..63cac961f 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1606,7 +1606,7 @@ void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
{
- auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
+ auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
// Ignore witness tables for builtin interface types.
if (isBuiltin(interfaceType))
@@ -1634,7 +1634,7 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
{
for (auto witnessTable : pendingWitnessTableDefinitions)
{
- auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
+ auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable);
m_writer->emit("extern \"C\"\n{\n");
m_writer->indent();
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index dc841bb3a..18641d8b7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1531,12 +1531,7 @@ struct IRWitnessTable : IRInst
IRInst* getConformanceType()
{
- return getOperand(0);
- }
-
- void setConformanceType(IRInst* type)
- {
- setOperand(0, type);
+ return cast<IRWitnessTableType>(getDataType())->getConformanceType();
}
IR_LEAF_ISA(WitnessTable)
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 1008c94a1..cc4c7e68d 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -571,7 +571,7 @@ IRWitnessTable* cloneWitnessTableImpl(
IRWitnessTable* clonedTable = dstTable;
if (!clonedTable)
{
- auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getOperand(0)));
+ auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType()));
clonedTable = builder->createWitnessTable(clonedBaseType);
}
cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index c02e9e3d6..f5bd469ac 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -201,8 +201,15 @@ namespace Slang
void lowerWitnessTable(IRWitnessTable* witnessTable)
{
auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType()));
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder->setInsertBefore(witnessTable);
if (interfaceType != witnessTable->getConformanceType())
- witnessTable->setConformanceType(interfaceType);
+ {
+ auto newWitnessTableType = builder->getWitnessTableType(interfaceType);
+ witnessTable->setFullType(newWitnessTableType);
+ }
if (isBuiltin(interfaceType))
return;
for (auto child : witnessTable->getChildren())
@@ -223,10 +230,6 @@ namespace Slang
{
// Translate a Type value to an RTTI object pointer.
auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal());
- IRBuilder builderStorage;
- auto builder = &builderStorage;
- builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
- builder->setInsertBefore(witnessTable);
auto rttiObjectPtr = builder->emitGetAddress(
builder->getPtrType(builder->getRTTIType()),
rttiObject);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 75b2beeec..7cbba9737 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -770,6 +770,86 @@ struct SpecializationContext
}
}
+ // Finds any `IRTargetIntrinsicDecoration` on `inst`, looking through `specialize` chains and generics to the value they return.
+ IRTargetIntrinsicDecoration* findTargetIntrinsicDecorationRec(IRInst* inst)
+ {
+ while (auto specialize = as<IRSpecialize>(inst))
+ {
+ inst = specialize->getBase();
+ }
+ while (auto genericInst = as<IRGeneric>(inst))
+ {
+ inst = findGenericReturnVal(genericInst);
+ }
+ if (auto decor = inst->findDecoration<IRTargetIntrinsicDecoration>())
+ return decor;
+ return nullptr;
+ }
+
+ // Returns true if the call inst represents a call to
+ // StructuredBuffer::operator[]/Load/Consume methods.
+ bool isBufferLoadCall(IRCall* inst)
+ {
+ if (auto targetIntrinsic = findTargetIntrinsicDecorationRec(inst->getCallee()))
+ {
+ auto name = targetIntrinsic->getDefinition();
+ if (name == ".operator[]" || name == ".Load" || name == ".Consume")
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// Transform a buffer load intrinsic call.
+ /// `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into
+ /// `wrapExistential(bufferLoad(bufferObj, loadArgs), wrapArgs)`.
+ /// Returns true if `inst` matches the pattern and the load is transformed, otherwise,
+ /// returns false.
+ bool maybeSpecializeBufferLoadCall(IRCall* inst)
+ {
+ if (isBufferLoadCall(inst))
+ {
+ SLANG_ASSERT(inst->getArgCount() > 0);
+ if (auto wrapExistential = as<IRWrapExistential>(inst->getArg(0)))
+ {
+ if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(
+ wrapExistential->getWrappedValue()->getDataType()))
+ {
+ // We are seeing the instruction sequence in the form of
+ // .operator[](wrapExistential(structuredBuffer), idx).
+ // Similar to handling load(wrapExistential(..)) insts,
+ // we need to rewrite it as wrapExistential(.operator[](sb, idx)).
+ auto resultType = inst->getFullType();
+ auto elementType = sbType->getElementType();
+
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(inst);
+
+ List<IRInst*> args;
+ args.add(wrapExistential->getWrappedValue());
+ for (UInt i = 1; i < inst->getArgCount(); i++)
+ args.add(inst->getArg(i));
+ List<IRInst*> slotOperands;
+ UInt slotOperandCount = wrapExistential->getSlotOperandCount();
+ for (UInt ii = 0; ii < slotOperandCount; ++ii)
+ {
+ slotOperands.add(wrapExistential->getSlotOperand(ii));
+ }
+ auto newCall = builder.emitCallInst(elementType, inst->getCallee(), args);
+ auto newWrapExistential = builder.emitWrapExistential(
+ resultType, newCall, slotOperandCount, slotOperands.getBuffer());
+ inst->replaceUsesWith(newWrapExistential);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newWrapExistential);
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
// Given a `call` instruction in the IR, we need to detect the case
// where the callee has some interface-type parameter(s) and at the
// call site it is statically clear what concrete type(s) the arguments
@@ -777,6 +857,12 @@ struct SpecializationContext
//
void maybeSpecializeExistentialsForCall(IRCall* inst)
{
+ // Handle a special case of `StructuredBuffer.operator[]/Load/Consume`
+ // calls first. These calls on builtin generic types should be handled
+ // the same way as a `load` inst.
+ if (maybeSpecializeBufferLoadCall(inst))
+ return;
+
// We can only specialize a call when the callee function is known.
//
auto calleeFunc = as<IRFunc>(inst->getCallee());
@@ -1678,13 +1764,18 @@ struct SpecializationContext
type->removeAndDeallocate();
return;
}
- else if( auto basePtrLikeType = as<IRPointerLikeType>(baseType) )
+ else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) )
{
// A `BindExistentials<P<T>, ...>` can be simplified to
// `P<BindExistentials<T, ...>>` when `P` is a pointer-like
// type constructor.
//
- auto baseElementType = basePtrLikeType->getElementType();
+ IRType* baseElementType = nullptr;
+ if (auto basePtrLikeType = as<IRPointerLikeType>(baseType))
+ baseElementType = basePtrLikeType->getElementType();
+ else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType))
+ baseElementType = baseSBType->getElementType();
+
IRInst* wrappedElementType = builder.getBindExistentialsType(
baseElementType,
slotOperandCount,
@@ -1692,7 +1783,7 @@ struct SpecializationContext
addToWorkList(wrappedElementType);
auto newPtrLikeType = builder.getType(
- basePtrLikeType->op,
+ baseType->op,
1,
&wrappedElementType);
addToWorkList(newPtrLikeType);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index bc7f7970b..28fa70edf 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2979,8 +2979,7 @@ namespace Slang
IRWitnessTable* witnessTable = createInst<IRWitnessTable>(
this,
kIROp_WitnessTable,
- getWitnessTableType(baseType),
- baseType);
+ getWitnessTableType(baseType));
addGlobalValue(this, witnessTable);
return witnessTable;
}
diff --git a/tests/compute/dynamic-dispatch-13.slang b/tests/compute/dynamic-dispatch-13.slang
new file mode 100644
index 000000000..723f42f52
--- /dev/null
+++ b/tests/compute/dynamic-dispatch-13.slang
@@ -0,0 +1,45 @@
+// Test using interface typed shader parameters wrapped inside a `StructuredBuffer`.
+
+//TEST(compute):COMPARE_COMPUTE:-cpu
+//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization
+
+[anyValueSize(8)]
+interface IInterface
+{
+ int run(int input);
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
+RWStructuredBuffer<int> gOutputBuffer;
+
+//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
+StructuredBuffer<IInterface> gCb;
+
+//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1
+StructuredBuffer<IInterface> gCb1;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ let tid = dispatchThreadID.x;
+
+ let inputVal : int = tid;
+ IInterface v0 = gCb.Load(0);
+ IInterface v1 = gCb1[0];
+ let outputVal = v0.run(inputVal) + v1.run(inputVal);
+
+ gOutputBuffer[tid] = outputVal;
+}
+
+// Specialize gCb, but not gCb1
+//TEST_INPUT: globalExistentialType MyImpl
+//TEST_INPUT: globalExistentialType __Dynamic
+// Type must be marked `public` to ensure it is visible in the generated DLL.
+public struct MyImpl : IInterface
+{
+ int val;
+ int run(int input)
+ {
+ return input + val;
+ }
+};
diff --git a/tests/compute/dynamic-dispatch-13.slang.expected.txt b/tests/compute/dynamic-dispatch-13.slang.expected.txt
new file mode 100644
index 000000000..628d2a8fd
--- /dev/null
+++ b/tests/compute/dynamic-dispatch-13.slang.expected.txt
@@ -0,0 +1,4 @@
+2
+4
+6
+8