diff options
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 22 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dispatch.cpp | 48 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp | 73 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-12.slang | 2 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-13.slang | 6 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-14.slang | 6 |
7 files changed, 125 insertions, 49 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index a0c46066c..c96286eec 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1481,21 +1481,19 @@ LinkedIR linkIR( cloneValue(context, bindInst); } } - if (target == CodeGenTarget::CPPSource || target == CodeGenTarget::CUDASource) + + for (IRModule* irModule : irModules) { - for (IRModule* irModule : irModules) + for (auto inst : irModule->getGlobalInsts()) { - for (auto inst : irModule->getGlobalInsts()) - { - auto hasPublic = inst->findDecoration<IRPublicDecoration>(); - if (!hasPublic) - continue; + auto hasPublic = inst->findDecoration<IRPublicDecoration>(); + if (!hasPublic) + continue; - auto cloned = cloneValue(context, inst); - if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration)) - { - context->builder->addKeepAliveDecoration(cloned); - } + auto cloned = cloneValue(context, inst); + if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration)) + { + context->builder->addKeepAliveDecoration(cloned); } } } diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 5f466c70c..9c852a3c1 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -17,6 +17,8 @@ namespace Slang { // Replace all uses of RTTI objects with its sequential ID. + // Currently we don't use RTTI objects at all, so all of them + // are 0. void specializeRTTIObjectReferences(SharedGenericsLoweringContext* sharedContext) { uint32_t id = 0; @@ -26,7 +28,12 @@ namespace Slang builder.sharedBuilder = &sharedContext->sharedBuilderStorage; builder.setInsertBefore(rtti.Value); IRUse* nextUse = nullptr; - auto idOperand = builder.getIntValue(builder.getUInt64Type(), id); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); + IRInst* uint2Args[] = { + builder.getIntValue(builder.getUIntType(), id), + builder.getIntValue(builder.getUIntType(), 0)}; + auto idOperand = builder.emitMakeVector(uint2Type, 2, uint2Args); for (auto use = rtti.Value->firstUse; use; use = nextUse) { nextUse = use->nextUse; @@ -38,7 +45,7 @@ namespace Slang } } - // Replace all WitnessTableID type or RTTIHandleType with uint64. + // Replace all WitnessTableID type or RTTIHandleType with `uint2`. void cleanUpRTTIHandleTypes(SharedGenericsLoweringContext* sharedContext) { List<IRInst*> instsToRemove; @@ -52,7 +59,9 @@ namespace Slang IRBuilder builder; builder.sharedBuilder = &sharedContext->sharedBuilderStorage; builder.setInsertBefore(inst); - inst->replaceUsesWith(builder.getUInt64Type()); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); + inst->replaceUsesWith(uint2Type); instsToRemove.add(inst); } break; @@ -99,6 +108,8 @@ namespace Slang if (sink->getErrorCount() != 0) return; + sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + specializeRTTIObjectReferences(sharedContext); cleanUpRTTIHandleTypes(sharedContext); diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index ebf3f1909..fc8f384ec 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -9,10 +9,9 @@ namespace Slang IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) { auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0); - + auto conformanceType = cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType(); // Collect all witness tables of `witnessTableType` in current module. - List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType( - cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType()); + List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(conformanceType); SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock()); auto block = dispatchFunc->getFirstBlock(); @@ -57,8 +56,8 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, paramTypes.add(paramInst->getFullType()); } - // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID. - paramTypes[0] = builder->getUIntType(); + // Modify the first paramter from IRWitnessTable to IRWitnessTableID representing the sequential ID. + paramTypes[0] = builder->getWitnessTableIDType((IRType*)conformanceType); auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType()); newDispatchFunc->setFullType(newDipsatchFuncType); @@ -79,6 +78,15 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, } auto witnessTableParam = newBlock->getFirstParam(); + // `witnessTableParam` is expected to have `IRWitnessTableID` type, which + // will later lower into a `uint2`. We only use the first element of the uint2 + // to store the sequential ID and reserve the second 32-bit value for future + // pointer-compatibility. We insert a member extract inst right now + // to obtain the first element and use it in our switch statement. + UInt elemIdx = 0; + auto witnessTableSequentialID = + builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx); + // Generate case blocks for each possible witness table. List<IRInst*> caseBlocks; for (Index i = 0; i < witnessTables.getCount(); i++) @@ -115,18 +123,28 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, // Emit a switch statement to call the correct concrete function based on // the witness table sequential ID passed in. builder->setInsertInto(newDispatchFunc); - auto breakBlock = builder->emitBlock(); - builder->setInsertInto(breakBlock); - builder->emitUnreachable(); - builder->setInsertInto(newBlock); - builder->emitSwitch( - witnessTableParam, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); + + if (witnessTables.getCount() == 1) + { + // If there is only 1 case, no switch statement is necessary. + builder->setInsertInto(newBlock); + builder->emitBranch(defaultBlock); + } + else + { + auto breakBlock = builder->emitBlock(); + builder->setInsertInto(breakBlock); + builder->emitUnreachable(); + builder->setInsertInto(newBlock); + builder->emitSwitch( + witnessTableSequentialID, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + } // Remove old implementation. dispatchFunc->replaceUsesWith(newDispatchFunc); dispatchFunc->removeAndDeallocate(); diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index a8d2902f6..eb77f651e 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -37,6 +37,15 @@ struct AssociatedTypeLookupSpecializationContext auto block = builder.emitBlock(); auto witnessTableParam = builder.emitParam(inputWitnessTableIDType); + // `witnessTableParam` is expected to have `IRWitnessTableID` type, which + // will later lower into a `uint2`. We only use the first element of the uint2 + // to store the sequential ID and reserve the second 32-bit value for future + // pointer-compatibility. We insert a member extract inst right now + // to obtain the first element and use it in our switch statement. + UInt elemIdx = 0; + auto witnessTableSequentialID = + builder.emitSwizzle(builder.getUIntType(), witnessTableParam, 1, &elemIdx); + // Collect all witness tables of `witnessTableType` in current module. List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(interfaceType); @@ -70,23 +79,41 @@ struct AssociatedTypeLookupSpecializationContext auto resultWitnessTableIDDecoration = resultWitnessTable->findDecoration<IRSequentialIDDecoration>(); SLANG_ASSERT(resultWitnessTableIDDecoration); - builder.emitReturn(resultWitnessTableIDDecoration->getSequentialIDOperand()); + // Pack the resulting witness table ID into a `uint2`. + auto uint2Type = builder.getVectorType( + builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); + IRInst* uint2Args[] = { + resultWitnessTableIDDecoration->getSequentialIDOperand(), + builder.getIntValue(builder.getUIntType(), 0)}; + auto resultID = builder.emitMakeVector(uint2Type, 2, uint2Args); + builder.emitReturn(resultID); } - // Emit a switch statement to return the correct witness table ID based on - // the witness table ID passed in. builder.setInsertInto(func); - auto breakBlock = builder.emitBlock(); - builder.setInsertInto(breakBlock); - builder.emitUnreachable(); - - builder.setInsertInto(block); - builder.emitSwitch( - witnessTableParam, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); + + if (witnessTables.getCount() == 1) + { + // If there is only 1 case, no switch statement is necessary. + builder.setInsertInto(block); + builder.emitBranch(defaultBlock); + } + else + { + // If there are more than 1 cases, + // emit a switch statement to return the correct witness table ID based on + // the witness table ID passed in. + auto breakBlock = builder.emitBlock(); + builder.setInsertInto(breakBlock); + builder.emitUnreachable(); + + builder.setInsertInto(block); + builder.emitSwitch( + witnessTableSequentialID, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + } return func; } @@ -176,12 +203,28 @@ struct AssociatedTypeLookupSpecializationContext }); // Replace all direct uses of IRWitnessTables with its sequential ID. - workOnModule([](IRInst* inst) + workOnModule([this](IRInst* inst) { if (inst->op == kIROp_WitnessTable) { auto seqId = inst->findDecoration<IRSequentialIDDecoration>(); SLANG_ASSERT(seqId); + // Insert code to pack sequential ID into an uint2 at all use sites. + for (auto use = inst->firstUse; use; ) + { + auto nextUse = use->nextUse; + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(use->getUser()); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); + IRInst* uint2Args[] = { + seqId->getSequentialIDOperand(), + builder.getIntValue(builder.getUIntType(), 0)}; + auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args); + use->set(uint2seqID); + use = nextUse; + } inst->replaceUsesWith(seqId->getSequentialIDOperand()); } }); diff --git a/tests/compute/dynamic-dispatch-12.slang b/tests/compute/dynamic-dispatch-12.slang index cd122ec56..11bfcc1eb 100644 --- a/tests/compute/dynamic-dispatch-12.slang +++ b/tests/compute/dynamic-dispatch-12.slang @@ -1,6 +1,8 @@ // Test using interface typed shader parameters with dynamic dispatch. +//TEST(compute):COMPARE_COMPUTE:-dx11 //TEST(compute):COMPARE_COMPUTE:-cpu +//TEST(compute):COMPARE_COMPUTE:-vk //TEST(compute):COMPARE_COMPUTE:-cuda [anyValueSize(8)] diff --git a/tests/compute/dynamic-dispatch-13.slang b/tests/compute/dynamic-dispatch-13.slang index 3c6c37691..e80e5ce5f 100644 --- a/tests/compute/dynamic-dispatch-13.slang +++ b/tests/compute/dynamic-dispatch-13.slang @@ -1,6 +1,8 @@ // Test using interface typed shader parameters wrapped inside a `StructuredBuffer`. //TEST(compute):COMPARE_COMPUTE:-cpu +//TEST(compute):COMPARE_COMPUTE:-dx11 +//TEST(compute):COMPARE_COMPUTE:-vk //TEST(compute):COMPARE_COMPUTE:-cuda [anyValueSize(8)] @@ -13,10 +15,10 @@ interface IInterface RWStructuredBuffer<int> gOutputBuffer; //TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb -StructuredBuffer<IInterface> gCb; +RWStructuredBuffer<IInterface> gCb; //TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1 -StructuredBuffer<IInterface> gCb1; +RWStructuredBuffer<IInterface> gCb1; [numthreads(4, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) diff --git a/tests/compute/dynamic-dispatch-14.slang b/tests/compute/dynamic-dispatch-14.slang index 5d84a3ee6..35da4bd06 100644 --- a/tests/compute/dynamic-dispatch-14.slang +++ b/tests/compute/dynamic-dispatch-14.slang @@ -1,6 +1,8 @@ // Test using interface typed shader parameters with associated types. +//TEST(compute):COMPARE_COMPUTE:-dx11 //TEST(compute):COMPARE_COMPUTE:-cpu +//TEST(compute):COMPARE_COMPUTE:-vk //TEST(compute):COMPARE_COMPUTE:-cuda [anyValueSize(8)] @@ -20,10 +22,10 @@ interface IInterface RWStructuredBuffer<int> gOutputBuffer; //TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb -StructuredBuffer<IInterface> gCb; +RWStructuredBuffer<IInterface> gCb; //TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1 -StructuredBuffer<IInterface> gCb1; +RWStructuredBuffer<IInterface> gCb1; [numthreads(4, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) |
