summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- source/slang/slang-ir-link.cpp 22
-rw-r--r-- source/slang/slang-ir-lower-generics.cpp 17
-rw-r--r-- source/slang/slang-ir-specialize-dispatch.cpp 48
-rw-r--r-- source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp 73
-rw-r--r-- tests/compute/dynamic-dispatch-12.slang 2
-rw-r--r-- tests/compute/dynamic-dispatch-13.slang 6
-rw-r--r-- tests/compute/dynamic-dispatch-14.slang 6
7 files changed, 125 insertions, 49 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index a0c46066c..c96286eec 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -1481,21 +1481,19 @@ LinkedIR linkIR(
cloneValue(context, bindInst);
}
}
- if (target == CodeGenTarget::CPPSource || target == CodeGenTarget::CUDASource)
+
+ for (IRModule* irModule : irModules)
{
- for (IRModule* irModule : irModules)
+ for (auto inst : irModule->getGlobalInsts())
{
- for (auto inst : irModule->getGlobalInsts())
- {
- auto hasPublic = inst->findDecoration<IRPublicDecoration>();
- if (!hasPublic)
- continue;
+ auto hasPublic = inst->findDecoration<IRPublicDecoration>();
+ if (!hasPublic)
+ continue;
- auto cloned = cloneValue(context, inst);
- if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration))
- {
- context->builder->addKeepAliveDecoration(cloned);
- }
+ auto cloned = cloneValue(context, inst);
+ if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration))
+ {
+ context->builder->addKeepAliveDecoration(cloned);
}
}
}
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 5f466c70c..9c852a3c1 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -17,6 +17,8 @@
namespace Slang
{
// Replace all uses of RTTI objects with its sequential ID.
+ // Currently we don't use RTTI objects at all, so all of them
+ // are 0.
void specializeRTTIObjectReferences(SharedGenericsLoweringContext* sharedContext)
{
uint32_t id = 0;
@@ -26,7 +28,12 @@ namespace Slang
builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
builder.setInsertBefore(rtti.Value);
IRUse* nextUse = nullptr;
- auto idOperand = builder.getIntValue(builder.getUInt64Type(), id);
+ auto uint2Type = builder.getVectorType(
+ builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
+ IRInst* uint2Args[] = {
+ builder.getIntValue(builder.getUIntType(), id),
+ builder.getIntValue(builder.getUIntType(), 0)};
+ auto idOperand = builder.emitMakeVector(uint2Type, 2, uint2Args);
for (auto use = rtti.Value->firstUse; use; use = nextUse)
{
nextUse = use->nextUse;
@@ -38,7 +45,7 @@ namespace Slang
}
}
- // Replace all WitnessTableID type or RTTIHandleType with uint64.
+ // Replace all WitnessTableID type or RTTIHandleType with `uint2`.
void cleanUpRTTIHandleTypes(SharedGenericsLoweringContext* sharedContext)
{
List<IRInst*> instsToRemove;
@@ -52,7 +59,9 @@ namespace Slang
IRBuilder builder;
builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
builder.setInsertBefore(inst);
- inst->replaceUsesWith(builder.getUInt64Type());
+ auto uint2Type = builder.getVectorType(
+ builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
+ inst->replaceUsesWith(uint2Type);
instsToRemove.add(inst);
}
break;
@@ -99,6 +108,8 @@ namespace Slang
if (sink->getErrorCount() != 0)
return;
+ sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+
specializeRTTIObjectReferences(sharedContext);
cleanUpRTTIHandleTypes(sharedContext);
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index ebf3f1909..fc8f384ec 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -9,10 +9,9 @@ namespace Slang
IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
-
+ auto conformanceType = cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType();
// Collect all witness tables of `witnessTableType` in current module.
- List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(
- cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType());
+ List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(conformanceType);
SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
auto block = dispatchFunc->getFirstBlock();
@@ -57,8 +56,8 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
paramTypes.add(paramInst->getFullType());
}
- // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID.
- paramTypes[0] = builder->getUIntType();
+ // Modify the first parameter from IRWitnessTable to IRWitnessTableID representing the sequential ID.
+ paramTypes[0] = builder->getWitnessTableIDType((IRType*)conformanceType);
auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType());
newDispatchFunc->setFullType(newDipsatchFuncType);
@@ -79,6 +78,15 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
}
auto witnessTableParam = newBlock->getFirstParam();
+ // `witnessTableParam` is expected to have `IRWitnessTableID` type, which
+ // will later lower into a `uint2`. We only use the first element of the uint2
+ // to store the sequential ID and reserve the second 32-bit value for future
+ // pointer-compatibility. We insert a member extract inst right now
+ // to obtain the first element and use it in our switch statement.
+ UInt elemIdx = 0;
+ auto witnessTableSequentialID =
+ builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx);
+
// Generate case blocks for each possible witness table.
List<IRInst*> caseBlocks;
for (Index i = 0; i < witnessTables.getCount(); i++)
@@ -115,18 +123,28 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
// Emit a switch statement to call the correct concrete function based on
// the witness table sequential ID passed in.
builder->setInsertInto(newDispatchFunc);
- auto breakBlock = builder->emitBlock();
- builder->setInsertInto(breakBlock);
- builder->emitUnreachable();
- builder->setInsertInto(newBlock);
- builder->emitSwitch(
- witnessTableParam,
- breakBlock,
- defaultBlock,
- caseBlocks.getCount(),
- caseBlocks.getBuffer());
+
+ if (witnessTables.getCount() == 1)
+ {
+ // If there is only 1 case, no switch statement is necessary.
+ builder->setInsertInto(newBlock);
+ builder->emitBranch(defaultBlock);
+ }
+ else
+ {
+ auto breakBlock = builder->emitBlock();
+ builder->setInsertInto(breakBlock);
+ builder->emitUnreachable();
+ builder->setInsertInto(newBlock);
+ builder->emitSwitch(
+ witnessTableSequentialID,
+ breakBlock,
+ defaultBlock,
+ caseBlocks.getCount(),
+ caseBlocks.getBuffer());
+ }
// Remove old implementation.
dispatchFunc->replaceUsesWith(newDispatchFunc);
dispatchFunc->removeAndDeallocate();
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index a8d2902f6..eb77f651e 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -37,6 +37,15 @@ struct AssociatedTypeLookupSpecializationContext
auto block = builder.emitBlock();
auto witnessTableParam = builder.emitParam(inputWitnessTableIDType);
+ // `witnessTableParam` is expected to have `IRWitnessTableID` type, which
+ // will later lower into a `uint2`. We only use the first element of the uint2
+ // to store the sequential ID and reserve the second 32-bit value for future
+ // pointer-compatibility. We insert a member extract inst right now
+ // to obtain the first element and use it in our switch statement.
+ UInt elemIdx = 0;
+ auto witnessTableSequentialID =
+ builder.emitSwizzle(builder.getUIntType(), witnessTableParam, 1, &elemIdx);
+
// Collect all witness tables of `witnessTableType` in current module.
List<IRWitnessTable*> witnessTables =
sharedContext->getWitnessTablesFromInterfaceType(interfaceType);
@@ -70,23 +79,41 @@ struct AssociatedTypeLookupSpecializationContext
auto resultWitnessTableIDDecoration =
resultWitnessTable->findDecoration<IRSequentialIDDecoration>();
SLANG_ASSERT(resultWitnessTableIDDecoration);
- builder.emitReturn(resultWitnessTableIDDecoration->getSequentialIDOperand());
+ // Pack the resulting witness table ID into a `uint2`.
+ auto uint2Type = builder.getVectorType(
+ builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
+ IRInst* uint2Args[] = {
+ resultWitnessTableIDDecoration->getSequentialIDOperand(),
+ builder.getIntValue(builder.getUIntType(), 0)};
+ auto resultID = builder.emitMakeVector(uint2Type, 2, uint2Args);
+ builder.emitReturn(resultID);
}
- // Emit a switch statement to return the correct witness table ID based on
- // the witness table ID passed in.
builder.setInsertInto(func);
- auto breakBlock = builder.emitBlock();
- builder.setInsertInto(breakBlock);
- builder.emitUnreachable();
-
- builder.setInsertInto(block);
- builder.emitSwitch(
- witnessTableParam,
- breakBlock,
- defaultBlock,
- caseBlocks.getCount(),
- caseBlocks.getBuffer());
+
+ if (witnessTables.getCount() == 1)
+ {
+ // If there is only 1 case, no switch statement is necessary.
+ builder.setInsertInto(block);
+ builder.emitBranch(defaultBlock);
+ }
+ else
+ {
+ // If there is more than one case,
+ // emit a switch statement to return the correct witness table ID based on
+ // the witness table ID passed in.
+ auto breakBlock = builder.emitBlock();
+ builder.setInsertInto(breakBlock);
+ builder.emitUnreachable();
+
+ builder.setInsertInto(block);
+ builder.emitSwitch(
+ witnessTableSequentialID,
+ breakBlock,
+ defaultBlock,
+ caseBlocks.getCount(),
+ caseBlocks.getBuffer());
+ }
return func;
}
@@ -176,12 +203,28 @@ struct AssociatedTypeLookupSpecializationContext
});
// Replace all direct uses of IRWitnessTables with its sequential ID.
- workOnModule([](IRInst* inst)
+ workOnModule([this](IRInst* inst)
{
if (inst->op == kIROp_WitnessTable)
{
auto seqId = inst->findDecoration<IRSequentialIDDecoration>();
SLANG_ASSERT(seqId);
+ // Insert code to pack the sequential ID into a uint2 at every use site.
+ for (auto use = inst->firstUse; use; )
+ {
+ auto nextUse = use->nextUse;
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder.setInsertBefore(use->getUser());
+ auto uint2Type = builder.getVectorType(
+ builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
+ IRInst* uint2Args[] = {
+ seqId->getSequentialIDOperand(),
+ builder.getIntValue(builder.getUIntType(), 0)};
+ auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args);
+ use->set(uint2seqID);
+ use = nextUse;
+ }
inst->replaceUsesWith(seqId->getSequentialIDOperand());
}
});
diff --git a/tests/compute/dynamic-dispatch-12.slang b/tests/compute/dynamic-dispatch-12.slang
index cd122ec56..11bfcc1eb 100644
--- a/tests/compute/dynamic-dispatch-12.slang
+++ b/tests/compute/dynamic-dispatch-12.slang
@@ -1,6 +1,8 @@
// Test using interface typed shader parameters with dynamic dispatch.
+//TEST(compute):COMPARE_COMPUTE:-dx11
//TEST(compute):COMPARE_COMPUTE:-cpu
+//TEST(compute):COMPARE_COMPUTE:-vk
//TEST(compute):COMPARE_COMPUTE:-cuda
[anyValueSize(8)]
diff --git a/tests/compute/dynamic-dispatch-13.slang b/tests/compute/dynamic-dispatch-13.slang
index 3c6c37691..e80e5ce5f 100644
--- a/tests/compute/dynamic-dispatch-13.slang
+++ b/tests/compute/dynamic-dispatch-13.slang
@@ -1,6 +1,8 @@
// Test using interface typed shader parameters wrapped inside a `StructuredBuffer`.
//TEST(compute):COMPARE_COMPUTE:-cpu
+//TEST(compute):COMPARE_COMPUTE:-dx11
+//TEST(compute):COMPARE_COMPUTE:-vk
//TEST(compute):COMPARE_COMPUTE:-cuda
[anyValueSize(8)]
@@ -13,10 +15,10 @@ interface IInterface
RWStructuredBuffer<int> gOutputBuffer;
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
-StructuredBuffer<IInterface> gCb;
+RWStructuredBuffer<IInterface> gCb;
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1
-StructuredBuffer<IInterface> gCb1;
+RWStructuredBuffer<IInterface> gCb1;
[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
diff --git a/tests/compute/dynamic-dispatch-14.slang b/tests/compute/dynamic-dispatch-14.slang
index 5d84a3ee6..35da4bd06 100644
--- a/tests/compute/dynamic-dispatch-14.slang
+++ b/tests/compute/dynamic-dispatch-14.slang
@@ -1,6 +1,8 @@
// Test using interface typed shader parameters with associated types.
+//TEST(compute):COMPARE_COMPUTE:-dx11
//TEST(compute):COMPARE_COMPUTE:-cpu
+//TEST(compute):COMPARE_COMPUTE:-vk
//TEST(compute):COMPARE_COMPUTE:-cuda
[anyValueSize(8)]
@@ -20,10 +22,10 @@ interface IInterface
RWStructuredBuffer<int> gOutputBuffer;
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
-StructuredBuffer<IInterface> gCb;
+RWStructuredBuffer<IInterface> gCb;
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1
-StructuredBuffer<IInterface> gCb1;
+RWStructuredBuffer<IInterface> gCb1;
[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)