summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2020-06-24 13:16:11 -0700
committerYong He <yonghe@outlook.com>2020-06-24 18:10:15 -0700
commit0ca75fe002f346f6ab9b77f40c0576d2905560f1 (patch)
treeed8a3af372900923e59f0d6da629c2d0969ee7fd
parent3fe4f5398d524333e955ecb91be5646e86f3b2da (diff)
Dynamic dispatch for generic interface requirements.
-Lower interfaces into actual `IRInterfaceType` insts. -Lower `DeclRef<AssocTypeDecl>` into `IRAssociatedType` -Generate proper IRType for generic functions. -Add a test case exercising dynamic dispatching a generic static function through an associated type. -Bug fixes for the test case.
-rw-r--r--source/slang/slang-emit-c-like.cpp21
-rw-r--r--source/slang/slang-emit-cpp.cpp64
-rw-r--r--source/slang/slang-emit-cpp.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h14
-rw-r--r--source/slang/slang-ir-insts.h25
-rw-r--r--source/slang/slang-ir-link.cpp5
-rw-r--r--source/slang/slang-ir-lower-generics.cpp181
-rw-r--r--source/slang/slang-ir.cpp32
-rw-r--r--source/slang/slang-ir.h19
-rw-r--r--source/slang/slang-lower-to-ir.cpp336
-rw-r--r--tests/compute/dynamic-dispatch-3.slang60
-rw-r--r--tests/compute/dynamic-dispatch-3.slang.expected.txt4
12 files changed, 593 insertions, 170 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 3438fd3f4..2605723c7 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -236,9 +236,9 @@ List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi
// Get a sorted list of entries using RequirementKeys defined in `interfaceType`.
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
{
- auto reqKey = cast<IRStructKey>(interfaceType->getOperand(i));
+ auto reqEntry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
IRWitnessTableEntry* entry = nullptr;
- if (witnessTableEntryDictionary.TryGetValue(reqKey, entry))
+ if (witnessTableEntryDictionary.TryGetValue(reqEntry->getRequirementKey(), entry))
{
sortedWitnessTableEntries.add(entry);
}
@@ -1962,6 +1962,10 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
are hashed with 'getStringHash' */
break;
+ case kIROp_undefined:
+ m_writer->emit(getName(inst));
+ break;
+
case kIROp_IntLit:
case kIROp_FloatLit:
case kIROp_BoolLit:
@@ -3554,6 +3558,11 @@ void CLikeSourceEmitter::emitGlobalInst(IRInst* inst)
are hashed with 'getStringHash' */
break;
+ case kIROp_InterfaceRequirementEntry:
+ // Don't emit anything for interface requirement at global level.
+ // They are handled in `emitInterface`.
+ break;
+
case kIROp_Func:
emitFunc((IRFunc*) inst);
break;
@@ -3610,6 +3619,10 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I
ensureInstOperand(ctx, inst->getFullType());
UInt operandCount = inst->operandCount;
+ auto requiredLevel = EmitAction::Definition;
+ if (inst->op == kIROp_InterfaceType)
+ requiredLevel = EmitAction::ForwardDeclaration;
+
for(UInt ii = 0; ii < operandCount; ++ii)
{
// TODO: there are some special cases we can add here,
@@ -3620,8 +3633,8 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I
// only need the type they point to to be forward-declared.
// Similarly, a `call` instruction only needs the callee
// to be forward-declared, etc.
-
- ensureInstOperand(ctx, inst->getOperand(ii));
+
+ ensureInstOperand(ctx, inst->getOperand(ii), requiredLevel);
}
for(auto child : inst->getDecorationsAndChildren())
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 4a59f4cf9..eeace4aa7 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -390,12 +390,27 @@ static UnownedStringSlice _getResourceTypePrefix(IROp op)
}
}
+static bool isVoidPtrType(IRType* type)
+{
+ auto ptrType = as<IRPtrType>(type);
+ if (!ptrType) return false;
+ return ptrType->getValueType()->op == kIROp_VoidType;
+}
+
SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out)
{
switch (type->op)
{
case kIROp_PtrType:
{
+ if (isVoidPtrType(type))
+ {
+ // A `void*` type will always emit as `void*`.
+ // `void*` types are generated as a result of generics lowering
+ // for dynamic dispatch.
+ out << "void*";
+ return SLANG_OK;
+ }
auto ptrType = static_cast<IRPtrType*>(type);
SLANG_RETURN_ON_FAIL(calcTypeName(ptrType->getValueType(), target, out));
// TODO(JS): It seems although it says it is a pointer, it can actually be output as a reference
@@ -494,7 +509,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
// struct of function pointers corresponding to the interface type.
auto witnessTableType = static_cast<IRWitnessTableType*>(type);
auto baseType = cast<IRType>(witnessTableType->getOperand(0));
- emitType(baseType);
+ SLANG_RETURN_ON_FAIL(calcTypeName(baseType, target, out));
out << "*";
return SLANG_OK;
}
@@ -1591,8 +1606,7 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
{
auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
auto witnessTableItems = witnessTable->getChildren();
- List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable);
- _maybeEmitWitnessTableTypeDefinition(interfaceType, sortedWitnessTableEntries);
+ _maybeEmitWitnessTableTypeDefinition(interfaceType);
// Define a global variable for the witness table.
m_writer->emit("extern ");
@@ -1747,17 +1761,16 @@ void CPPSourceEmitter::emitInterface(IRInterfaceType* interfaceType)
/// acoording to the order defined by `interfaceType`.
///
void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
- IRInterfaceType* interfaceType,
- const List<IRWitnessTableEntry*>& sortedWitnessTableEntries)
+ IRInterfaceType* interfaceType)
{
m_writer->emit("struct ");
emitSimpleType(interfaceType);
m_writer->emit("\n{\n");
m_writer->indent();
- for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++)
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
{
- auto entry = sortedWitnessTableEntries[i];
- if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get()))
+ auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
+ if (auto funcVal = as<IRFuncType>(entry->getRequirementVal()))
{
emitType(funcVal->getResultType());
m_writer->emit(" (KernelContext::*");
@@ -1765,33 +1778,35 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
m_writer->emit(")");
m_writer->emit("(");
bool isFirstParam = true;
- for (auto param : funcVal->getParams())
+ for (UInt p = 0; p < funcVal->getParamCount(); p++)
{
+ auto paramType = funcVal->getParamType(p);
+ // Ingore TypeType-typed parameters for now.
+ if (as<IRTypeType>(paramType))
+ continue;
+
if (!isFirstParam)
m_writer->emit(", ");
else
isFirstParam = false;
- if (param->findDecoration<IRThisPointerDecoration>())
+ auto thisDecor = funcVal->findDecoration<IRThisPointerDecoration>();
+ if (thisDecor && cast<IRIntLit>(thisDecor->getOperand(0))->value.intVal == (IRIntegerValue)p)
{
- m_writer->emit("void* ");
- m_writer->emit(getName(param));
+ m_writer->emit("void* param");
+ m_writer->emit(p);
continue;
}
- emitSimpleFuncParamImpl(param);
+ emitParamType(paramType, String("param") + String(p));
}
m_writer->emit(");\n");
}
- else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal()))
+ else if (auto constraintInterfaceType = as<IRInterfaceType>(entry->getRequirementVal()))
{
- emitType(as<IRType>(witnessTableVal->getOperand(0)));
+ emitType(constraintInterfaceType);
m_writer->emit("* ");
m_writer->emit(getName(entry->requirementKey.get()));
m_writer->emit(";\n");
}
- else
- {
- // TODO: handle other witness table entry types.
- }
}
m_writer->dedent();
m_writer->emit("};\n");
@@ -1990,13 +2005,6 @@ void CPPSourceEmitter::emitSimpleValueImpl(IRInst* inst)
}
}
-static bool isVoidPtrType(IRType* type)
-{
- auto ptrType = as<IRPtrType>(type);
- if (!ptrType) return false;
- return ptrType->getValueType()->op == kIROp_VoidType;
-}
-
void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
{
// Polymorphic types are already translated to void* type in
@@ -2004,9 +2012,7 @@ void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
// emit "void&" instead of "void*" for pointer types.
// In the future, we will handle pointer types more properly,
// and this override logic will not be necessary.
- // For now we special-case this scenario.
- if (param->findDecoration<IRPolymorphicDecoration>() &&
- isVoidPtrType(param->getDataType()))
+ if (isVoidPtrType(param->getDataType()))
{
m_writer->emit("void* ");
m_writer->emit(getName(param));
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 47ba03d70..6f91444a3 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -89,7 +89,7 @@ protected:
virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder);
// Emits a struct of function pointers defined in `interfaceType`.
- void _maybeEmitWitnessTableTypeDefinition(IRInterfaceType* interfaceType, const List<IRWitnessTableEntry*>& sortedWitnessTableEntries);
+ void _maybeEmitWitnessTableTypeDefinition(IRInterfaceType* interfaceType);
void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp);
void _emitForwardDeclarations(const List<EmitAction>& actions);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 58ff1a79f..e9bc23993 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -164,7 +164,8 @@ INST(Nop, nop, 0, 0)
// `field` instructions.
//
INST(StructType, struct, 0, PARENT)
-INST(InterfaceType, interface, 0, PARENT)
+INST(InterfaceType, interface, 0, 0)
+INST(AssociatedType, associated_type, 0, 0)
// A TypeType-typed IRValue represents a IRType.
// It is used to represent a type parameter/argument in a generics.
@@ -223,6 +224,7 @@ INST(Call, call, 1, 0)
INST(WitnessTableEntry, witness_table_entry, 2, 0)
+INST(InterfaceRequirementEntry, interface_req_entry, 2, 0)
INST(Param, param, 0, 0)
INST(StructField, field, 2, 0)
@@ -507,14 +509,12 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0)
- /// A `[polymorphic]` decoration marks a function parameter that should translate to an abstract type
- /// e.g. (void*) that are casted to actual type before use. For example, a parameter of generic type
- /// is marked `[polymorphic]`, so that the code gen logic can emit it as a `void*` parameter,
- /// allowing the function to be used at sites that are agnostic of the actual object type.
- INST(PolymorphicDecoration, polymorphic, 0, 0)
/// A `[this_ptr]` decoration marks a function parameter that serves as `this` pointer.
- INST(ThisPointerDecoration, this_ptr, 0, 0)
+ /// `[this_ptr]` decoration is also used to mark an `IRFunc` as a non-static function.
+ /// The argument is an integer value that represents the index of the `this` parameter,
+ /// which is always 0.
+ INST(ThisPointerDecoration, this_ptr, 1, 0)
/// A `[format(f)]` decoration specifies that the format of an image should be `f`
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index b13d52981..fb0cc57c7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -165,8 +165,6 @@ IR_SIMPLE_DECORATION(VulkanCallablePayloadDecoration)
/// vulkan hit attributes, and should have a location assigned
/// to it.
IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration)
-
-IR_SIMPLE_DECORATION(PolymorphicDecoration)
IR_SIMPLE_DECORATION(ThisPointerDecoration)
@@ -410,9 +408,13 @@ struct IRLookupWitnessMethod : IRInst
{
IRUse witnessTable;
IRUse requirementKey;
+ IRUse interfaceType;
IRInst* getWitnessTable() { return witnessTable.get(); }
IRInst* getRequirementKey() { return requirementKey.get(); }
+ IRInst* getInterfaceType() { return interfaceType.get(); }
+
+ IR_LEAF_ISA(lookup_interface_method)
};
struct IRLookupWitnessTable : IRInst
@@ -1675,7 +1677,8 @@ struct IRBuilder
IRInst* emitLookupInterfaceMethodInst(
IRType* type,
IRInst* witnessTableVal,
- IRInst* interfaceMethodVal);
+ IRInst* interfaceMethodVal,
+ IRType* interfaceType);
IRInst* emitCallInst(
IRType* type,
@@ -1809,9 +1812,16 @@ struct IRBuilder
IRInst* requirementKey,
IRInst* satisfyingVal);
+ IRInterfaceRequirementEntry* createInterfaceRequirementEntry(
+ IRInst* requirementKey,
+ IRInst* requirementVal);
+
// Create an initially empty `struct` type.
IRStructType* createStructType();
+ // Create an IRType representing an `associatedtype` decl.
+ IRAssociatedType* createAssociatedType();
+
// Create an empty `interface` type.
IRInterfaceType* createInterfaceType(UInt operandCount, IRInst* const* operands);
@@ -2160,14 +2170,9 @@ struct IRBuilder
addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode)));
}
- void addPolymorphicDecoration(IRInst* value)
- {
- addDecoration(value, kIROp_PolymorphicDecoration);
- }
-
- void addThisPointerDecoration(IRInst* value)
+ void addThisPointerDecoration(IRInst* value, int paramIndex)
{
- addDecoration(value, kIROp_ThisPointerDecoration);
+ addDecoration(value, kIROp_ThisPointerDecoration, getIntValue(getIntType(), paramIndex));
}
void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 3f51aa876..4e6ad74a4 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -228,6 +228,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
case kIROp_StructKey:
case kIROp_GlobalGenericParam:
case kIROp_WitnessTable:
+ case kIROp_InterfaceType:
case kIROp_TaggedUnionType:
return cloneGlobalValue(this, originalValue);
@@ -607,8 +608,7 @@ IRInterfaceType* cloneInterfaceTypeImpl(
auto clonedInterface = builder->createInterfaceType(originalInterface->getOperandCount(), nullptr);
for (UInt i = 0; i < originalInterface->getOperandCount(); i++)
{
- auto clonedKey = findClonedValue(context, originalInterface->getOperand(i));
- SLANG_ASSERT(clonedKey);
+ auto clonedKey = cloneValue(context, originalInterface->getOperand(i));
clonedInterface->setOperand(i, clonedKey);
}
cloneSimpleGlobalValueImpl(context, originalInterface, originalValues, clonedInterface);
@@ -628,6 +628,7 @@ void cloneGlobalValueWithCodeCommon(
cloneDecorations(context, clonedValue, originalValue);
cloneExtraDecorations(context, clonedValue, originalValues);
+ clonedValue->setFullType((IRType*)cloneValue(context, originalValue->getFullType()));
// We will walk through the blocks of the function, and clone each of them.
//
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index f6340a633..fe0fa3364 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -16,6 +16,7 @@ namespace Slang
IRModule* module;
Dictionary<IRInst*, IRInst*> loweredGenericFunctions;
+ HashSet<IRInterfaceType*> loweredInterfaceTypes;
SharedIRBuilder sharedBuilderStorage;
@@ -45,6 +46,20 @@ namespace Slang
workListSet.Add(inst);
}
+ bool isPolymorphicType(IRInst* typeInst)
+ {
+ if (as<IRParam>(typeInst) && as<IRTypeType>(typeInst->getFullType()))
+ return true;
+ switch (typeInst->op)
+ {
+ case kIROp_AssociatedType:
+ case kIROp_InterfaceType:
+ return true;
+ default:
+ return false;
+ }
+ }
+
IRInst* lowerGenericFunction(IRInst* genericValue)
{
IRInst* result = nullptr;
@@ -64,6 +79,7 @@ namespace Slang
builder.sharedBuilder = &sharedBuilderStorage;
builder.setInsertBefore(genericParent);
auto loweredFunc = cloneInstAndOperands(&cloneEnv, &builder, func);
+ loweredFunc->setFullType(lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->typeUse.get())));
List<IRInst*> clonedParams;
for (auto genericParam : genericParent->getParams())
{
@@ -82,7 +98,7 @@ namespace Slang
// Turn generic parameters into void pointers.
for (auto param : cast<IRFunc>(loweredFunc)->getParams())
{
- if (param->findDecoration<IRPolymorphicDecoration>())
+ if (isPolymorphicType(param->getFullType()))
{
param->setFullType(builder.getPtrType(builder.getVoidType()));
}
@@ -91,6 +107,106 @@ namespace Slang
return loweredFunc;
}
+ IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal)
+ {
+ List<IRInst*> genericParamTypes;
+ for (auto genericParam : genericVal->getParams())
+ {
+ if (isPolymorphicType(genericParam->getFullType()))
+ {
+ genericParamTypes.add(builder->getPtrType(builder->getVoidType()));
+ }
+ else
+ {
+ genericParamTypes.add(genericParam->getFullType());
+ }
+ }
+
+ auto innerType = (IRFuncType*)lowerFuncType(
+ builder,
+ cast<IRFuncType>(findGenericReturnVal(genericVal)),
+ genericParamTypes.getCount());
+
+ for (int i = 0; i < genericParamTypes.getCount(); i++)
+ {
+ innerType->setOperand(
+ innerType->getOperandCount() - genericParamTypes.getCount() + i,
+ genericParamTypes[i]);
+ }
+
+ return innerType;
+ }
+
+ IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, UInt additionalParamCount = 0)
+ {
+ List<IRInst*> newOperands;
+ bool translated = false;
+ for (UInt i = 0; i < funcType->getOperandCount(); i++)
+ {
+ auto paramType = funcType->getOperand(i);
+ if (isPolymorphicType(paramType))
+ {
+ newOperands.add(builder->getPtrType(builder->getVoidType()));
+ translated = true;
+ }
+ else if (paramType->op == kIROp_Specialize)
+ {
+ // TODO: handle static specialized type here.
+ // For now treat all specialized types as dynamic.
+ // In the future, we need to turn things like Array<IDynamic> into Array<void*>.
+ newOperands.add(builder->getPtrType(builder->getVoidType()));
+ translated = true;
+ }
+ else
+ {
+ newOperands.add(paramType);
+ }
+ }
+ if (!translated)
+ return funcType;
+ for (UInt i = 0; i < additionalParamCount; i++)
+ {
+ newOperands.add(nullptr);
+ }
+ auto newFuncType = builder->getFuncType(
+ newOperands.getCount() - 1,
+ (IRType**)(newOperands.begin() + 1),
+ (IRType*)newOperands[0]);
+
+ IRCloneEnv cloneEnv;
+ cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, funcType, newFuncType);
+ return newFuncType;
+ }
+
+ IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType)
+ {
+ if (loweredInterfaceTypes.Contains(interfaceType))
+ return interfaceType;
+
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(interfaceType);
+
+ // Translate IRFuncType in interface requirements.
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ {
+ if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)))
+ {
+ if (auto funcType = as<IRFuncType>(entry->getRequirementVal()))
+ {
+ entry->requirementVal.set(lowerFuncType(&builder, funcType));
+ }
+ else if (auto genericFuncType = as<IRGeneric>(entry->getRequirementVal()))
+ {
+ entry->requirementVal.set(lowerGenericFuncType(&builder, genericFuncType));
+ }
+ }
+ }
+
+ loweredInterfaceTypes.Add(interfaceType);
+ return interfaceType;
+ }
+
void processInst(IRInst* inst)
{
if (auto callInst = as<IRCall>(inst))
@@ -98,25 +214,53 @@ namespace Slang
// If we see a call(specialize(gFunc, Targs), args),
// translate it into call(gFunc, args, Targs).
auto funcOperand = callInst->getOperand(0);
+ IRInst* loweredFunc = nullptr;
if (auto specializeInst = as<IRSpecialize>(funcOperand))
{
- auto loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
- if (loweredFunc == specializeInst->getOperand(0))
+ auto funcToSpecialize = specializeInst->getOperand(0);
+ List<IRType*> paramTypes;
+ if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
{
- // This is an intrinsic function, don't transform.
- return;
+ // The callee is a result of witness table lookup, we will only
+ // translate the call.
+ IRInst* callee = nullptr;
+ auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(interfaceLookup->getInterfaceType()));
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ {
+ auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
+ if (entry->getRequirementKey() == interfaceLookup->getOperand(1))
+ {
+ callee = entry->getRequirementVal();
+ break;
+ }
+ }
+ auto funcType = cast<IRFuncType>(callee);
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ paramTypes.add(funcType->getParamType(i));
+ loweredFunc = funcToSpecialize;
+ }
+ else
+ {
+ loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
+ if (loweredFunc == specializeInst->getOperand(0))
+ {
+ // This is an intrinsic function, don't transform.
+ return;
+ }
+ for (auto param : as<IRFunc>(loweredFunc)->getParams())
+ paramTypes.add(param->getDataType());
}
+
IRBuilder builderStorage;
auto builder = &builderStorage;
builder->sharedBuilder = &sharedBuilderStorage;
builder->setInsertBefore(inst);
List<IRInst*> args;
- auto pp = as<IRFunc>(loweredFunc)->getParams().begin();
auto voidPtrType = builder->getPtrType(builder->getVoidType());
for (UInt i = 0; i < callInst->getArgCount(); i++)
{
auto arg = callInst->getArg(i);
- if ((*pp)->getDataType() == voidPtrType &&
+ if (paramTypes[i] == voidPtrType &&
arg->getDataType() != voidPtrType)
{
// We are calling a generic function that with an argument of
@@ -132,7 +276,6 @@ namespace Slang
arg);
}
args.add(arg);
- ++pp;
}
for (UInt i = 0; i < specializeInst->getArgCount(); i++)
args.add(specializeInst->getArg(i));
@@ -141,6 +284,28 @@ namespace Slang
callInst->removeAndDeallocate();
}
}
+ else if (auto witnessTable = as<IRWitnessTable>(inst))
+ {
+ // Lower generic functions in witness table.
+ for (auto child : witnessTable->getChildren())
+ {
+ auto entry = as<IRWitnessTableEntry>(child);
+ if (!entry)
+ continue;
+ if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal()))
+ {
+ if (findGenericReturnVal(genericVal)->op == kIROp_Func)
+ {
+ auto loweredFunc = lowerGenericFunction(genericVal);
+ entry->satisfyingVal.set(loweredFunc);
+ }
+ }
+ }
+ }
+ else if (auto interfaceType = as<IRInterfaceType>(inst))
+ {
+ maybeLowerInterfaceType(interfaceType);
+ }
}
void processModule()
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 77011b569..891f4b3e0 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2508,14 +2508,16 @@ namespace Slang
IRInst* IRBuilder::emitLookupInterfaceMethodInst(
IRType* type,
IRInst* witnessTableVal,
- IRInst* interfaceMethodVal)
+ IRInst* interfaceMethodVal,
+ IRType* interfaceType)
{
+ IRInst* args[3] = { witnessTableVal , interfaceMethodVal, interfaceType };
auto inst = createInst<IRLookupWitnessMethod>(
this,
kIROp_lookup_interface_method,
type,
- witnessTableVal,
- interfaceMethodVal);
+ 3,
+ args);
addInst(inst);
return inst;
@@ -2811,6 +2813,20 @@ namespace Slang
return entry;
}
+ IRInterfaceRequirementEntry* IRBuilder::createInterfaceRequirementEntry(
+ IRInst* requirementKey,
+ IRInst* requirementVal)
+ {
+ IRInterfaceRequirementEntry* entry = createInst<IRInterfaceRequirementEntry>(
+ this,
+ kIROp_InterfaceRequirementEntry,
+ nullptr,
+ requirementKey,
+ requirementVal);
+ addGlobalValue(this, entry);
+ return entry;
+ }
+
IRStructType* IRBuilder::createStructType()
{
IRStructType* structType = createInst<IRStructType>(
@@ -2821,6 +2837,16 @@ namespace Slang
return structType;
}
+ IRAssociatedType* IRBuilder::createAssociatedType()
+ {
+ IRAssociatedType* associatedType = createInst<IRAssociatedType>(
+ this,
+ kIROp_AssociatedType,
+ nullptr);
+ addGlobalValue(this, associatedType);
+ return associatedType;
+ }
+
IRInterfaceType* IRBuilder::createInterfaceType(UInt operandCount, IRInst* const* operands)
{
IRInterfaceType* interfaceType = createInst<IRInterfaceType>(
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 3c9a15650..b41c94e7f 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1190,6 +1190,25 @@ struct IRStructType : IRType
IR_LEAF_ISA(StructType)
};
+struct IRAssociatedType : IRType
+{
+ IR_LEAF_ISA(AssociatedType)
+};
+
+struct IRInterfaceRequirementEntry : IRInst
+{
+ // The AST-level requirement
+ IRUse requirementKey;
+
+ // The IR-level value that represents the declaration of the requirement
+ IRUse requirementVal;
+
+ IRInst* getRequirementKey() { return getOperand(0); }
+ IRInst* getRequirementVal() { return getOperand(1); }
+
+ IR_LEAF_ISA(InterfaceRequirementEntry);
+};
+
struct IRInterfaceType : IRType
{
IR_LEAF_ISA(InterfaceType)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ea04ea85c..ff356fd48 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -401,6 +401,18 @@ struct IRGenContext
{
return shared->m_mainModuleDecl;
}
+
+ LoweredValInfo* findLoweredDecl(Decl* decl)
+ {
+ IRGenEnv* envToFindIn = env;
+ while (envToFindIn)
+ {
+ if (auto rs = envToFindIn->mapDeclToValue.TryGetValue(decl))
+ return rs;
+ envToFindIn = envToFindIn->outer;
+ }
+ return nullptr;
+ }
};
void setGlobalValue(SharedIRGenContext* sharedContext, Decl* decl, LoweredValInfo value)
@@ -986,6 +998,8 @@ IRStructKey* getInterfaceRequirementKey(
IRGenContext* context,
Decl* requirementDecl)
{
+ if (auto genericDecl = as<GenericDecl>(requirementDecl))
+ return getInterfaceRequirementKey(context, genericDecl->inner);
IRStructKey* requirementKey = nullptr;
if(context->shared->interfaceRequirementKeys.TryGetValue(requirementDecl, requirementKey))
{
@@ -1059,7 +1073,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst(
nullptr,
baseWitnessTable,
- requirementKey));
+ requirementKey,
+ lowerType(context, val->subToMid->sup)));
}
LoweredValInfo visitTaggedUnionSubtypeWitness(
@@ -1240,7 +1255,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto caseFunc = subBuilder->emitLookupInterfaceMethodInst(
caseFuncType,
caseWitnessTable,
- irReqKey);
+ irReqKey,
+ irWitnessTableBaseType);
// We are going to emit a `call` to the satisfying value
// for the case type, so we will collect the arguments for that call.
@@ -4520,7 +4536,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable)
{
auto subBuilder = subContext->irBuilder;
-
+
for(auto entry : astWitnessTable->requirementDictionary)
{
auto requiredMemberDecl = entry.Key;
@@ -5275,11 +5291,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// a witness table for the interface type's conformance
// to its own interface.
//
- List<IRStructKey*> requirementKeys;
+ NestedContext nestedContext(this);
+ auto subBuilder = nestedContext.getBuilder();
+ auto subContext = nestedContext.getContext();
+ List<IRInterfaceRequirementEntry*> requirementEntries;
+
for (auto requirementDecl : decl->members)
{
- requirementKeys.add(getInterfaceRequirementKey(requirementDecl));
-
+ auto key = getInterfaceRequirementKey(requirementDecl);
+ auto entry = subBuilder->createInterfaceRequirementEntry(key, nullptr);
+ requirementEntries.add(entry);
// As a special case, any type constraints placed
// on an associated type will *also* need to be turned
// into requirement keys for this interface.
@@ -5287,22 +5308,20 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>())
{
- requirementKeys.add(getInterfaceRequirementKey(constraintDecl));
+ auto constraintKey = getInterfaceRequirementKey(constraintDecl);
+ requirementEntries.add(
+ subBuilder->createInterfaceRequirementEntry(constraintKey,
+ lowerType(context, constraintDecl->getSup().type)));
}
}
}
-
- NestedContext nestedContext(this);
- auto subBuilder = nestedContext.getBuilder();
- auto subContext = nestedContext.getContext();
-
// Emit any generics that should wrap the actual type.
emitOuterGenerics(subContext, decl, decl);
IRInterfaceType* irInterface = subBuilder->createInterfaceType(
- requirementKeys.getCount(),
- reinterpret_cast<IRInst**>(requirementKeys.getBuffer()));
+ requirementEntries.getCount(),
+ reinterpret_cast<IRInst**>(requirementEntries.getBuffer()));
addNameHint(context, irInterface, decl);
addLinkageDecoration(context, irInterface, decl);
subBuilder->setInsertInto(irInterface);
@@ -5389,63 +5408,76 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Emit any generics that should wrap the actual type.
emitOuterGenerics(subContext, decl, decl);
- IRStructType* irStruct = subBuilder->createStructType();
- addNameHint(context, irStruct, decl);
- addLinkageDecoration(context, irStruct, decl);
+ IRInst* resultType = nullptr;
+ if (as<AssocTypeDecl>(decl))
+ {
+ resultType = subBuilder->createAssociatedType();
+ }
+ else
+ {
+ resultType = subBuilder->createStructType();
+ }
- subBuilder->setInsertInto(irStruct);
+ addNameHint(context, resultType, decl);
+ addLinkageDecoration(context, resultType, decl);
- // A `struct` that inherits from another `struct` must start
- // with a member for the direct base type.
- //
- for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() )
+ if (resultType->op == kIROp_StructType)
{
- auto superType = inheritanceDecl->base;
- if(auto superDeclRefType = as<DeclRefType>(superType))
+ IRStructType* irStruct = (IRStructType*)resultType;
+ subBuilder->setInsertInto(irStruct);
+
+ // A `struct` that inherits from another `struct` must start
+ // with a member for the direct base type.
+ //
+ for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() )
{
- if(auto superStructDeclRef = superDeclRefType->declRef.as<StructDecl>())
+ auto superType = inheritanceDecl->base;
+ if(auto superDeclRefType = as<DeclRefType>(superType))
{
- auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl));
- auto irSuperType = lowerType(context, superType.type);
- subBuilder->createStructField(
- irStruct,
- superKey,
- irSuperType);
+ if(auto superStructDeclRef = superDeclRefType->declRef.as<StructDecl>())
+ {
+ auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl));
+ auto irSuperType = lowerType(context, superType.type);
+ subBuilder->createStructField(
+ irStruct,
+ superKey,
+ irSuperType);
+ }
}
}
- }
- for (auto fieldDecl : decl->getMembersOfType<VarDeclBase>())
- {
- if (fieldDecl->hasModifier<HLSLStaticModifier>())
+ for (auto fieldDecl : decl->getMembersOfType<VarDeclBase>())
{
- // A `static` field is actually a global variable,
- // and we should emit it as such.
- ensureDecl(context, fieldDecl);
- continue;
- }
-
- // Each ordinary field will need to turn into a struct "key"
- // that is used for fetching the field.
- IRInst* fieldKeyInst = getSimpleVal(context,
- ensureDecl(context, fieldDecl));
- auto fieldKey = as<IRStructKey>(fieldKeyInst);
- SLANG_ASSERT(fieldKey);
-
- // Note: we lower the type of the field in the "sub"
- // context, so that any generic parameters that were
- // set up for the type can be referenced by the field type.
- IRType* fieldType = lowerType(
- subContext,
- fieldDecl->getType());
+ if (fieldDecl->hasModifier<HLSLStaticModifier>())
+ {
+ // A `static` field is actually a global variable,
+ // and we should emit it as such.
+ ensureDecl(context, fieldDecl);
+ continue;
+ }
- // Then, the parent `struct` instruction itself will have
- // a "field" instruction.
- subBuilder->createStructField(
- irStruct,
- fieldKey,
- fieldType);
+ // Each ordinary field will need to turn into a struct "key"
+ // that is used for fetching the field.
+ IRInst* fieldKeyInst = getSimpleVal(context,
+ ensureDecl(context, fieldDecl));
+ auto fieldKey = as<IRStructKey>(fieldKeyInst);
+ SLANG_ASSERT(fieldKey);
+
+ // Note: we lower the type of the field in the "sub"
+ // context, so that any generic parameters that were
+ // set up for the type can be referenced by the field type.
+ IRType* fieldType = lowerType(
+ subContext,
+ fieldDecl->getType());
+
+ // Then, the parent `struct` instruction itself will have
+ // a "field" instruction.
+ subBuilder->createStructField(
+ irStruct,
+ fieldKey,
+ fieldType);
+ }
}
// There may be members not handled by the above logic (e.g.,
@@ -5455,10 +5487,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Instead we will force emission of all children of aggregate
// type declarations later, from the top-level emit logic.
- irStruct->moveToEnd();
- addTargetIntrinsicDecorations(irStruct, decl);
+ resultType->moveToEnd();
+ addTargetIntrinsicDecorations(resultType, decl);
- return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irStruct));
+ return LoweredValInfo::simple(finishOuterGenerics(subBuilder, resultType));
}
LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl)
@@ -5995,29 +6027,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return as<IRStringLit>(builder->getStringValue(stringLitExpr->value.getUnownedSlice()));
}
- LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
+ void _lowerFuncResultAndParameterTypes(
+ ParameterLists& parameterLists,
+ List<IRType*>& paramTypes,
+ IRType*& irResultType,
+ IRBuilder* subBuilder,
+ IRGenContext* subContext,
+ FunctionDeclBase* decl)
{
- // We are going to use a nested builder, because we will
- // change the parent node that things get nested into.
- //
- NestedContext nestedContext(this);
- auto subBuilder = nestedContext.getBuilder();
- auto subContext = nestedContext.getContext();
-
- // The actual `IRFunction` that we emit needs to be nested
- // inside of one `IRGeneric` for every outer `GenericDecl`
- // in the declaration hierarchy.
-
- emitOuterGenerics(subContext, decl, decl);
-
// Collect the parameter lists we will use for our new function.
- ParameterLists parameterLists;
collectParameterLists(decl, &parameterLists, kParameterListCollectMode_Default);
- // TODO: if there are any generic parameters in the collected list, then
- // we need to output an IR function with generic parameters (or a generic
- // with a nested function... the exact representation is still TBD).
-
// In most cases the return type for a declaration can be read off the declaration
// itself, but things get a bit more complicated when we have to deal with
// accessors for subscript declarations (and eventually for properties).
@@ -6036,14 +6056,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
- // need to create an IR function here
-
- IRFunc* irFunc = subBuilder->createFunc();
- addNameHint(context, irFunc, decl);
- addLinkageDecoration(context, irFunc, decl);
-
- List<IRType*> paramTypes;
-
for( auto paramInfo : parameterLists.params )
{
IRType* irParamType = lowerType(subContext, paramInfo.type);
@@ -6054,10 +6066,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Simple case of a by-value input parameter.
break;
- // If the parameter is declared `out` or `inout`,
- // then we will represent it with a pointer type in
- // the IR, but we will use a specialized pointer
- // type that encodes the parameter direction information.
+ // If the parameter is declared `out` or `inout`,
+ // then we will represent it with a pointer type in
+ // the IR, but we will use a specialized pointer
+ // type that encodes the parameter direction information.
case kParameterDirection_Out:
irParamType = subBuilder->getOutType(irParamType);
break;
@@ -6084,7 +6096,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
paramTypes.add(irParamType);
}
- auto irResultType = lowerType(subContext, declForReturnType->returnType);
+ irResultType = lowerType(subContext, declForReturnType->returnType);
if (auto setterDecl = as<SetterDecl>(decl))
{
@@ -6107,11 +6119,83 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// being accessed, rather than a simple value.
irResultType = subBuilder->getPtrType(irResultType);
}
+ }
- auto irFuncType = subBuilder->getFuncType(
+ IRFuncType* _lowerFuncTypeImpl(
+ ParameterLists& parameterLists,
+ List<IRType*>& paramTypes,
+ IRType*& irResultType,
+ IRBuilder* builder,
+ IRGenContext* irGenContext,
+ FunctionDeclBase* decl)
+ {
+ _lowerFuncResultAndParameterTypes(
+ parameterLists,
+ paramTypes,
+ irResultType,
+ builder,
+ irGenContext,
+ decl);
+
+ auto irFuncType = builder->getFuncType(
paramTypes.getCount(),
paramTypes.getBuffer(),
irResultType);
+
+ if (parameterLists.params.getCount() && parameterLists.params[0].isThisParam)
+ builder->addThisPointerDecoration(irFuncType, 0);
+ return irFuncType;
+ }
+
+ IRInst* lowerFuncType(FunctionDeclBase* decl)
+ {
+ NestedContext nestedContextFuncType(this);
+ auto funcTypeBuilder = nestedContextFuncType.getBuilder();
+ auto funcTypeContext = nestedContextFuncType.getContext();
+
+ emitOuterGenerics(funcTypeContext, decl, decl);
+
+ ParameterLists parameterLists;
+ List<IRType*> paramTypes;
+ IRType* irResultType = nullptr;
+ auto irFuncType = _lowerFuncTypeImpl(
+ parameterLists,
+ paramTypes,
+ irResultType,
+ funcTypeBuilder,
+ funcTypeContext,
+ decl);
+
+ return finishOuterGenerics(funcTypeBuilder, irFuncType);
+ }
+
+ LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
+ {
+ // We are going to use a nested builder, because we will
+ // change the parent node that things get nested into.
+ //
+ NestedContext nestedContextFunc(this);
+ auto subBuilder = nestedContextFunc.getBuilder();
+ auto subContext = nestedContextFunc.getContext();
+
+ emitOuterGenerics(subContext, decl, decl);
+
+ // need to create an IR function here
+
+ IRFunc* irFunc = subBuilder->createFunc();
+ addNameHint(context, irFunc, decl);
+ addLinkageDecoration(context, irFunc, decl);
+
+ ParameterLists parameterLists;
+ List<IRType*> paramTypes;
+ IRType* irResultType = nullptr;
+ auto irFuncType = _lowerFuncTypeImpl(
+ parameterLists,
+ paramTypes,
+ irResultType,
+ subBuilder,
+ subContext,
+ decl);
irFunc->setFullType(irFuncType);
subBuilder->setInsertInto(irFunc);
@@ -6251,14 +6335,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (paramInfo.isThisParam)
{
subContext->thisVal = paramVal;
- subBuilder->addThisPointerDecoration(irParam);
- }
-
- // Add a [polymorphic] decoration for generic-typed parameters.
- if (as<IRParam>(irParamType) &&
- as<IRTypeType>(irParamType->getFullType()))
- {
- subBuilder->addPolymorphicDecoration(irParam);
+ subBuilder->addThisPointerDecoration(irParam, (int)(paramTypeIndex - 1));
}
}
@@ -6470,7 +6547,53 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// body appear before the function itself in the list
// of global values.
irFunc->moveToEnd();
- return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irFunc));
+
+ // If this function is defined inside an interface, add a reference to the IRFunc from
+ // the interface's type definition.
+ auto finalVal = finishOuterGenerics(subBuilder, irFunc);
+
+ if (auto genericVal = as<IRGeneric>(finalVal))
+ {
+ auto funcType = lowerFuncType(decl);
+ genericVal->typeUse.set(funcType);
+ }
+
+ maybeAssociateToInterfaceType(decl, finalVal);
+
+ return LoweredValInfo::simple(finalVal);
+ }
+
+ void maybeAssociateToInterfaceType(Decl* decl, IRInst* irFuncVal)
+ {
+ auto parent = decl->parentDecl;
+ InterfaceDecl* interfaceDecl = nullptr;
+ while (parent)
+ {
+ interfaceDecl = as<InterfaceDecl>(parent);
+ if (interfaceDecl) break;
+ parent = parent->parentDecl;
+ }
+ if (!interfaceDecl)
+ return;
+ auto loweredVal = context->findLoweredDecl(interfaceDecl);
+ if (!loweredVal)
+ {
+ return;
+ }
+ IRInst* irFuncType = irFuncVal->typeUse.get();
+ auto irInterfaceType = cast<IRInterfaceType>(loweredVal->val);
+ auto key = getInterfaceRequirementKey(decl);
+ for (UInt i = 0; i < irInterfaceType->getOperandCount(); i++)
+ {
+ auto operand = cast<IRInterfaceRequirementEntry>(irInterfaceType->getOperand(i));
+ if (operand->getOperand(0) == key)
+ {
+ operand->setOperand(1, irFuncType);
+ return;
+ }
+ }
+ SLANG_UNREACHABLE("associating interface function declaration:"
+ "requirement not found in the interface type.");
}
LoweredValInfo visitGenericDecl(GenericDecl * genDecl)
@@ -6759,7 +6882,8 @@ LoweredValInfo emitDeclRef(
auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst(
type,
irWitnessTable,
- irRequirementKey);
+ irRequirementKey,
+ lowerType(context, thisTypeSubst->witness->sup));
return LoweredValInfo::simple(irSatisfyingVal);
}
else
diff --git a/tests/compute/dynamic-dispatch-3.slang b/tests/compute/dynamic-dispatch-3.slang
new file mode 100644
index 000000000..7011a2f4e
--- /dev/null
+++ b/tests/compute/dynamic-dispatch-3.slang
@@ -0,0 +1,60 @@
+//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code
+
+// Test dynamic dispatch code gen for static member functions
+// of associated type.
+interface IGetter
+{
+ int getVal();
+};
+interface IAssoc
+{
+ int get();
+ static int getBase<T:IGetter>(T getter);
+}
+interface IInterface
+{
+ associatedtype Assoc : IAssoc;
+ int Compute(int inVal);
+};
+
+struct GetterImpl : IGetter
+{
+ int getVal() { return 1; }
+};
+
+int GenericCompute<T:IInterface>(T obj, int inVal)
+{
+ GetterImpl getter;
+ return obj.Compute(inVal) + T.Assoc.getBase(getter);
+}
+
+struct Impl : IInterface
+{
+ struct Assoc : IAssoc
+ {
+ int val;
+ int get() { return val; }
+ static int getBase<T:IGetter>(T t) { return t.getVal(); }
+ };
+ int base;
+ int Compute(int inVal) { return base + inVal * inVal; }
+};
+
+int test(int inVal)
+{
+ Impl obj;
+ obj.base = 1;
+ return GenericCompute<Impl>(obj, inVal);
+}
+
+//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer : register(u0);
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = outputBuffer[tid];
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/compute/dynamic-dispatch-3.slang.expected.txt b/tests/compute/dynamic-dispatch-3.slang.expected.txt
new file mode 100644
index 000000000..a6bafb7ca
--- /dev/null
+++ b/tests/compute/dynamic-dispatch-3.slang.expected.txt
@@ -0,0 +1,4 @@
+2
+3
+6
+B