diff options
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-function.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 97 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 3 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-13.slang | 45 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-13.slang.expected.txt | 4 |
10 files changed, 163 insertions, 20 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index fe35f8a19..7793dfcbc 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -74,6 +74,12 @@ namespace Slang loc); return; } + else if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type)) + { + _collectExistentialSpecializationParamsRec( + astBuilder, ioSpecializationParams, structuredBufferType->getElementType(), loc); + return; + } if( auto declRefType = as<DeclRefType>(type) ) { diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 090d06f28..8ef18242a 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -222,7 +222,7 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWitnessTable* witnessTable) { List<IRWitnessTableEntry*> sortedWitnessTableEntries; - auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0)); + auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); auto witnessTableItems = witnessTable->getChildren(); // Build a dictionary of witness table entries for fast lookup. Dictionary<IRInst*, IRWitnessTableEntry*> witnessTableEntryDictionary; diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 650b8aa8f..63cac961f 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1606,7 +1606,7 @@ void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) { - auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0)); + auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); // Ignore witness tables for builtin interface types. if (isBuiltin(interfaceType)) @@ -1634,7 +1634,7 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() { for (auto witnessTable : pendingWitnessTableDefinitions) { - auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0)); + auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable); m_writer->emit("extern \"C\"\n{\n"); m_writer->indent(); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index dc841bb3a..18641d8b7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1531,12 +1531,7 @@ struct IRWitnessTable : IRInst IRInst* getConformanceType() { - return getOperand(0); - } - - void setConformanceType(IRInst* type) - { - setOperand(0, type); + return cast<IRWitnessTableType>(getDataType())->getConformanceType(); } IR_LEAF_ISA(WitnessTable) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 1008c94a1..cc4c7e68d 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -571,7 +571,7 @@ IRWitnessTable* cloneWitnessTableImpl( IRWitnessTable* clonedTable = dstTable; if (!clonedTable) { - auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getOperand(0))); + auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType())); clonedTable = builder->createWitnessTable(clonedBaseType); } cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index c02e9e3d6..f5bd469ac 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -201,8 +201,15 @@ namespace Slang void lowerWitnessTable(IRWitnessTable* witnessTable) { auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType())); + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(witnessTable); if (interfaceType != witnessTable->getConformanceType()) - witnessTable->setConformanceType(interfaceType); + { + auto newWitnessTableType = builder->getWitnessTableType(interfaceType); + witnessTable->setFullType(newWitnessTableType); + } if (isBuiltin(interfaceType)) return; for (auto child : witnessTable->getChildren()) @@ -223,10 +230,6 @@ namespace Slang { // Translate a Type value to an RTTI object pointer. auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedContext->sharedBuilderStorage; - builder->setInsertBefore(witnessTable); auto rttiObjectPtr = builder->emitGetAddress( builder->getPtrType(builder->getRTTIType()), rttiObject); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 75b2beeec..7cbba9737 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -770,6 +770,86 @@ struct SpecializationContext } } + // Finds any `IRTargetDecoration` from `inst`. Recursively chasing `specialize` chains. + IRTargetIntrinsicDecoration* findTargetIntrinsicDecorationRec(IRInst* inst) + { + while (auto specialize = as<IRSpecialize>(inst)) + { + inst = specialize->getBase(); + } + while (auto genericInst = as<IRGeneric>(inst)) + { + inst = findGenericReturnVal(genericInst); + } + if (auto decor = inst->findDecoration<IRTargetIntrinsicDecoration>()) + return decor; + return nullptr; + } + + // Returns true if the call inst represents a call to + // StructuredBuffer::operator[]/Load/Consume methods. + bool isBufferLoadCall(IRCall* inst) + { + if (auto targetIntrinsic = findTargetIntrinsicDecorationRec(inst->getCallee())) + { + auto name = targetIntrinsic->getDefinition(); + if (name == ".operator[]" || name == ".Load" || name == ".Consume") + { + return true; + } + } + return false; + } + + /// Transform a buffer load intrinsic call. + /// `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into + /// `wrapExistential(bufferLoad(bufferObj, loadArgs), wragArgs)`. + /// Returns true if `inst` matches the pattern and the load is transformed, otherwise, + /// returns false. + bool maybeSpecializeBufferLoadCall(IRCall* inst) + { + if (isBufferLoadCall(inst)) + { + SLANG_ASSERT(inst->getArgCount() > 0); + if (auto wrapExistential = as<IRWrapExistential>(inst->getArg(0))) + { + if (auto sbType = as<IRHLSLStructuredBufferTypeBase>( + wrapExistential->getWrappedValue()->getDataType())) + { + // We are seeing the instruction sequence in the form of + // .operator[](wrapExistential(structuredBuffer), idx). + // Similar to handling load(wrapExistential(..)) insts, + // we need to replace it into wrapExistential(.operator[](sb, idx)) + auto resultType = inst->getFullType(); + auto elementType = sbType->getElementType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + List<IRInst*> args; + args.add(wrapExistential->getWrappedValue()); + for (UInt i = 1; i < inst->getArgCount(); i++) + args.add(inst->getArg(i)); + List<IRInst*> slotOperands; + UInt slotOperandCount = wrapExistential->getSlotOperandCount(); + for (UInt ii = 0; ii < slotOperandCount; ++ii) + { + slotOperands.add(wrapExistential->getSlotOperand(ii)); + } + auto newCall = builder.emitCallInst(elementType, inst->getCallee(), args); + auto newWrapExistential = builder.emitWrapExistential( + resultType, newCall, slotOperandCount, slotOperands.getBuffer()); + inst->replaceUsesWith(newWrapExistential); + inst->removeAndDeallocate(); + addUsersToWorkList(newWrapExistential); + return true; + } + } + } + return false; + } + // Given a `call` instruction in the IR, we need to detect the case // where the callee has some interface-type parameter(s) and at the // call site it is statically clear what concrete type(s) the arguments @@ -777,6 +857,12 @@ struct SpecializationContext // void maybeSpecializeExistentialsForCall(IRCall* inst) { + // Handle a special case of `StructuredBuffer.operator[]/Load/Consume` + // calls first. These calls on builtin generic types should be handled + // the same way as a `load` inst. + if (maybeSpecializeBufferLoadCall(inst)) + return; + // We can only specialize a call when the callee function is known. // auto calleeFunc = as<IRFunc>(inst->getCallee()); @@ -1678,13 +1764,18 @@ struct SpecializationContext type->removeAndDeallocate(); return; } - else if( auto basePtrLikeType = as<IRPointerLikeType>(baseType) ) + else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) ) { // A `BindExistentials<P<T>, ...>` can be simplified to // `P<BindExistentials<T, ...>>` when `P` is a pointer-like // type constructor. // - auto baseElementType = basePtrLikeType->getElementType(); + IRType* baseElementType = nullptr; + if (auto basePtrLikeType = as<IRPointerLikeType>(baseType)) + baseElementType = basePtrLikeType->getElementType(); + else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType)) + baseElementType = baseSBType->getElementType(); + IRInst* wrappedElementType = builder.getBindExistentialsType( baseElementType, slotOperandCount, @@ -1692,7 +1783,7 @@ struct SpecializationContext addToWorkList(wrappedElementType); auto newPtrLikeType = builder.getType( - basePtrLikeType->op, + baseType->op, 1, &wrappedElementType); addToWorkList(newPtrLikeType); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index bc7f7970b..28fa70edf 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2979,8 +2979,7 @@ namespace Slang IRWitnessTable* witnessTable = createInst<IRWitnessTable>( this, kIROp_WitnessTable, - getWitnessTableType(baseType), - baseType); + getWitnessTableType(baseType)); addGlobalValue(this, witnessTable); return witnessTable; } diff --git a/tests/compute/dynamic-dispatch-13.slang b/tests/compute/dynamic-dispatch-13.slang new file mode 100644 index 000000000..723f42f52 --- /dev/null +++ b/tests/compute/dynamic-dispatch-13.slang @@ -0,0 +1,45 @@ +// Test using interface typed shader parameters wrapped inside a `StructuredBuffer`. + +//TEST(compute):COMPARE_COMPUTE:-cpu +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization + +[anyValueSize(8)] +interface IInterface +{ + int run(int input); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer +RWStructuredBuffer<int> gOutputBuffer; + +//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb +StructuredBuffer<IInterface> gCb; + +//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1 +StructuredBuffer<IInterface> gCb1; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + let tid = dispatchThreadID.x; + + let inputVal : int = tid; + IInterface v0 = gCb.Load(0); + IInterface v1 = gCb1[0]; + let outputVal = v0.run(inputVal) + v1.run(inputVal); + + gOutputBuffer[tid] = outputVal; +} + +// Specialize gCb1, but not gCb2 +//TEST_INPUT: globalExistentialType MyImpl +//TEST_INPUT: globalExistentialType __Dynamic +// Type must be marked `public` to ensure it is visible in the generated DLL. +public struct MyImpl : IInterface +{ + int val; + int run(int input) + { + return input + val; + } +}; diff --git a/tests/compute/dynamic-dispatch-13.slang.expected.txt b/tests/compute/dynamic-dispatch-13.slang.expected.txt new file mode 100644 index 000000000..628d2a8fd --- /dev/null +++ b/tests/compute/dynamic-dispatch-13.slang.expected.txt @@ -0,0 +1,4 @@ +2 +4 +6 +8 |
