diff options
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 49 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.cpp | 13 | ||||
| -rw-r--r-- | tests/compute/interface-func-param-in-struct.slang | 43 | ||||
| -rw-r--r-- | tests/compute/interface-func-param-in-struct.slang.expected.txt | 4 |
4 files changed, 102 insertions, 7 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index d364de3fb..ceba4b03c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -808,6 +808,24 @@ struct SpecializationContext return false; } + /// Used by `maybeSpecailizeBufferLoadCall`, this function returns a new specialized callee that + /// replaces a `specialize(.operator[], oldType)` to `specialize(.operator[], newElementType)`. + IRInst* getNewSpecializedBufferLoadCallee( + IRInst* oldSpecializedCallee, + IRType* newContainerType, + IRType* newElementType) + { + auto oldSpecialize = cast<IRSpecialize>(oldSpecializedCallee); + SLANG_ASSERT(oldSpecialize->getArgCount() == 1); + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(oldSpecializedCallee); + auto calleeType = builder.getFuncType(1, &newContainerType, newElementType); + auto newSpecialize = builder.emitSpecializeInst( + calleeType, oldSpecialize->getBase(), 1, (IRInst**)&newElementType); + return newSpecialize; + } + /// Transform a buffer load intrinsic call. /// `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into /// `wrapExistential(bufferLoad(bufferObj, loadArgs), wragArgs)`. @@ -844,11 +862,18 @@ struct SpecializationContext { slotOperands.add(wrapExistential->getSlotOperand(ii)); } - auto newCall = builder.emitCallInst(elementType, inst->getCallee(), args); + // The old callee should be in the form of `specialize(.operator[], IInterfaceType)`, + // we should update it to be `specialize(.operator[], elementType)`, so the return type + // of the load call is `elementType`. + auto oldCallee = inst->getCallee(); + auto newCallee = getNewSpecializedBufferLoadCallee(inst->getCallee(), sbType, elementType); + auto newCall = builder.emitCallInst(elementType, newCallee, args); auto newWrapExistential = builder.emitWrapExistential( resultType, newCall, slotOperandCount, slotOperands.getBuffer()); inst->replaceUsesWith(newWrapExistential); inst->removeAndDeallocate(); + SLANG_ASSERT(!oldCallee->hasUses()); + oldCallee->removeAndDeallocate(); addUsersToWorkList(newWrapExistential); return true; } @@ -1080,7 +1105,8 @@ struct SpecializationContext // if(as<IRInterfaceType>(type)) return true; - + if (calcExistentialTypeParamSlotCount(type) != 0) + return true; // Eventually we will also want to handle arrays over // existential types, but that will require careful // handling in many places. @@ -1518,6 +1544,11 @@ struct SpecializationContext type = arrayType->getElementType(); goto top; } + else if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(type)) + { + type = sbType->getElementType(); + goto top; + } else if( auto structType = as<IRStructType>(type) ) { UInt count = 0; @@ -1800,6 +1831,11 @@ struct SpecializationContext type = ptrLikeType->getElementType(); goto top; } + else if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(type)) + { + type = sbType->getElementType(); + goto top; + } else if( auto structType = as<IRStructType>(type) ) { UInt count = 0; @@ -1872,15 +1908,15 @@ struct SpecializationContext baseElementType, slotOperandCount, type->getExistentialArgs()); - addToWorkList(wrappedElementType); auto newPtrLikeType = builder.getType( baseType->op, 1, &wrappedElementType); + addUsersToWorkList(type); addToWorkList(newPtrLikeType); + addToWorkList(wrappedElementType); - addUsersToWorkList(type); type->replaceUsesWith(newPtrLikeType); type->removeAndDeallocate(); return; @@ -1911,10 +1947,13 @@ struct SpecializationContext } IRStructType* newStructType = nullptr; + addUsersToWorkList(type); + if( !existentialSpecializedStructs.TryGetValue(key, newStructType) ) { builder.setInsertBefore(baseStructType); newStructType = builder.createStructType(); + addToWorkList(newStructType); auto fieldSlotArgs = type->getExistentialArgs(); @@ -1939,10 +1978,8 @@ struct SpecializationContext } existentialSpecializedStructs.Add(key, newStructType); - addToWorkList(newStructType); } - addUsersToWorkList(type); type->replaceUsesWith(newStructType); type->removeAndDeallocate(); return; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 88ec5f4f7..33bdb4ef4 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -2536,8 +2536,19 @@ createStructuredBufferTypeLayout( typeLayout->addResourceUsage(info.kind, info.size); } + // If element type contains existential type params and object params, + // we need to propagate them through the StructuredBufferLayout. + if (auto existentialTypeInfo = elementTypeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) + { + typeLayout->addResourceUsage(existentialTypeInfo->kind, existentialTypeInfo->count); + } + if (auto existentialObjInfo = elementTypeLayout->FindResourceInfo(LayoutResourceKind::ExistentialObjectParam)) + { + typeLayout->addResourceUsage(existentialObjInfo->kind, existentialObjInfo->count); + } + // Note: for now we don't deal with the case of a structured - // buffer that might contain anything other than "uniform" data, + // buffer that might contain any other resource types, // because there really isn't a way to implement that. return typeLayout; diff --git a/tests/compute/interface-func-param-in-struct.slang b/tests/compute/interface-func-param-in-struct.slang new file mode 100644 index 000000000..550758f38 --- /dev/null +++ b/tests/compute/interface-func-param-in-struct.slang @@ -0,0 +1,43 @@ +// Tests specializing a function with existential-struct-typed param. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cuda +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cpu + +[anyValueSize(8)] +interface IInterface +{ + uint eval(); +} + +struct Impl : IInterface +{ + uint val; + uint eval() + { + return val; + } +}; + +struct Params +{ + StructuredBuffer<IInterface> obj; +}; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer +RWStructuredBuffer<uint> gOutputBuffer; + +void compute(uint tid, Params p) +{ + gOutputBuffer[tid] = p.obj[0].eval(); +} + +//TEST_INPUT: entryPointExistentialType Impl + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID, +//TEST_INPUT:ubuffer(data=[0 0 0 0 1 0], stride=4):name=params.obj + uniform Params params) +{ + uint tid = dispatchThreadID.x; + compute(tid, params); +}
\ No newline at end of file diff --git a/tests/compute/interface-func-param-in-struct.slang.expected.txt b/tests/compute/interface-func-param-in-struct.slang.expected.txt new file mode 100644 index 000000000..98fb6a686 --- /dev/null +++ b/tests/compute/interface-func-param-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +1 +1 +1 +1 |
