summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-12 22:58:22 -0700
committerGitHub <noreply@github.com>2023-04-12 22:58:22 -0700
commitca7bf79df3a3f5f4494912cb0572c36662755b9d (patch)
tree64b14034326be8285c0265e74ad3ed11e29ff062
parent12ec9b832fc74faba7162e54e04f7f48878ea88e (diff)
Combine lookupWitness lowering with specialization. (#2794)
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/slang-ast-type.cpp8
-rw-r--r--source/slang/slang-emit.cpp4
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp16
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp14
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp13
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-generics-lowering-context.cpp7
-rw-r--r--source/slang/slang-ir-generics-lowering-context.h16
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h20
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp10
-rw-r--r--source/slang/slang-ir-lower-witness-lookup.cpp407
-rw-r--r--source/slang/slang-ir-lower-witness-lookup.h16
-rw-r--r--source/slang/slang-ir-peephole.cpp44
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp3
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp2
-rw-r--r--source/slang/slang-ir-specialize.cpp14
-rw-r--r--source/slang/slang-ir-specialize.h4
-rw-r--r--source/slang/slang-ir-util.cpp23
-rw-r--r--source/slang/slang-ir-util.h4
-rw-r--r--source/slang/slang-lower-to-ir.cpp5
-rw-r--r--tests/diagnostics/no-type-conformance.slang.expected6
-rw-r--r--tests/ir/dynamic-generic-method-specialize.slang65
-rw-r--r--tests/ir/dynamic-generic-method-specialize.slang.expected.txt2
-rw-r--r--tests/language-server/robustness-7.slang63
-rw-r--r--tests/language-server/robustness-7.slang.expected.txt12
31 files changed, 735 insertions, 60 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index c36dd269d..875c6d399 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -411,6 +411,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-reinterpret.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-result-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-marshal-native-call.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-metadata.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-missing-return.h" />
@@ -602,6 +603,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-reinterpret.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-result-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-marshal-native-call.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-metadata.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-missing-return.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index f353468e4..9ef8a68b3 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -321,6 +321,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-marshal-native-call.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -890,6 +893,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-marshal-native-call.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 1fed2d52a..72f1764df 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -69,8 +69,8 @@ Type* Type::getCanonicalType()
// TODO(tfoley): worry about thread safety here?
auto canType = et->createCanonicalType();
et->canonicalType = canType;
-
- SLANG_ASSERT(et->canonicalType);
+ if (!et->canonicalType)
+ return getASTBuilder()->getErrorType();
}
return et->canonicalType;
}
@@ -481,7 +481,9 @@ Type* NamedExpressionType::_createCanonicalTypeOverride()
{
if (!innerType)
innerType = getType(m_astBuilder, declRef);
- return innerType->getCanonicalType();
+ if (innerType)
+ return innerType->getCanonicalType();
+ return nullptr;
}
HashCode NamedExpressionType::_getHashCodeOverride()
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index ca141bd7a..e2948561c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -374,7 +374,9 @@ Result linkAndOptimizeIR(
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
- changed |= specializeModule(irModule);
+ changed |= specializeModule(irModule, codeGenContext->getSink());
+ if (codeGenContext->getSink()->getErrorCount() != 0)
+ return SLANG_FAIL;
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
eliminateDeadCode(irModule);
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index 79a41f0d5..27bba3fde 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -666,6 +666,7 @@ namespace Slang
case kIROp_IntType:
case kIROp_FloatType:
case kIROp_UIntType:
+ case kIROp_BoolType:
return alignUp(offset, 4) + 4;
case kIROp_UInt64Type:
case kIROp_Int64Type:
@@ -755,6 +756,21 @@ namespace Slang
size += kRTTIHeaderSize;
return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
}
+ case kIROp_AssociatedType:
+ {
+ auto associatedType = cast<IRAssociatedType>(type);
+ SlangInt maxSize = 0;
+ for (UInt i = 0; i < associatedType->getOperandCount(); i++)
+ maxSize = Math::Max(maxSize, _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset));
+ return maxSize;
+ }
+ case kIROp_ThisType:
+ {
+ auto thisType = cast<IRThisType>(type);
+ auto interfaceType = thisType->getConstraintType();
+ auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
+ return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
+ }
case kIROp_ExtractExistentialType:
{
auto existentialValue = type->getOperand(0);
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index fe3c70bde..80ee37988 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -410,6 +410,7 @@ struct CFGNormalizationPass
}
case kIROp_loop:
+ case kIROp_Switch:
{
auto breakBlock = normalizeBreakableRegion(terminator);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 6c3d6a934..2be394537 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -877,19 +877,6 @@ InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst)
return InstPair(nullptr, nullptr);
}
-IRInst* ForwardDiffTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
-{
- for (UInt i = 0; i < type->getOperandCount(); i++)
- {
- if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
- {
- if (req->getRequirementKey() == key)
- return req->getRequirementVal();
- }
- }
- return nullptr;
-}
-
InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
{
auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
@@ -1810,6 +1797,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_CastIntToFloat:
case kIROp_CastFloatToInt:
case kIROp_DetachDerivative:
+ case kIROp_GetSequentialID:
return trascribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 8fd271fd8..a9193acbe 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -62,8 +62,6 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeConst(IRBuilder* builder, IRInst* origInst);
- IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
-
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst);
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 04d5560d9..363572c86 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -771,6 +771,7 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialWitnessTable:
case kIROp_undefined:
+ case kIROp_GetSequentialID:
return false;
case kIROp_GetElement:
case kIROp_FieldExtract:
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 9cbea7873..81c0ab235 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -645,19 +645,6 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
return nullptr;
}
-IRInst* AutoDiffTranscriberBase::findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
-{
- for (UInt i = 0; i < type->getOperandCount(); i++)
- {
- if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
- {
- if (req->getRequirementKey() == key)
- return req->getRequirementVal();
- }
- }
- return nullptr;
-}
-
InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* origParam)
{
auto primalDataType = findOrTranscribePrimalInst(builder, origParam->getDataType());
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index d5070689e..86af2fb8e 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -111,8 +111,6 @@ struct AutoDiffTranscriberBase
IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType);
- IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
-
IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp
index 8e726cba1..0dbc84e51 100644
--- a/source/slang/slang-ir-generics-lowering-context.cpp
+++ b/source/slang/slang-ir-generics-lowering-context.cpp
@@ -339,7 +339,7 @@ namespace Slang
}
}
- List<IRWitnessTable*> SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType(IRInst* interfaceType)
+ List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType)
{
List<IRWitnessTable*> witnessTables;
for (auto globalInst : module->getGlobalInsts())
@@ -354,6 +354,11 @@ namespace Slang
return witnessTables;
}
+ List<IRWitnessTable*> SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType(IRInst* interfaceType)
+ {
+ return Slang::getWitnessTablesFromInterfaceType(module, interfaceType);
+ }
+
IRIntegerValue SharedGenericsLoweringContext::getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc)
{
SLANG_UNUSED(usageLoc);
diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h
index 0c07a93c9..8030751d0 100644
--- a/source/slang/slang-ir-generics-lowering-context.h
+++ b/source/slang/slang-ir-generics-lowering-context.h
@@ -74,7 +74,7 @@ namespace Slang
IRInst* maybeEmitRTTIObject(IRInst* typeInst);
static IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc);
- IRType* lowerAssociatedType(IRBuilder* builder, IRInst* type);
+ static IRType* lowerAssociatedType(IRBuilder* builder, IRInst* type);
IRType* lowerType(IRBuilder* builder, IRInst* paramType, const Dictionary<IRInst*, IRInst*>& typeMapping, IRType* concreteType);
@@ -86,20 +86,12 @@ namespace Slang
// Get a list of all witness tables whose conformance type is `interfaceType`.
List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRInst* interfaceType);
- IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
- {
- for (auto entry : table->getEntries())
- {
- if (entry->getRequirementKey() == key)
- return entry->getSatisfyingVal();
- }
- return nullptr;
- }
-
- /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`?
+ /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`?
bool doesTypeFitInAnyValue(IRType* concreteType, IRInterfaceType* interfaceType, IRIntegerValue* outTypeSize = nullptr, IRIntegerValue* outLimit = nullptr);
};
+ List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType);
+
bool isPolymorphicType(IRInst* typeInst);
// Returns true if typeInst represents a type and should be lowered into
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 8d06c6970..a8ec5a66f 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -740,7 +740,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(AnyValueSizeDecoration, AnyValueSize, 1, 0)
INST(SpecializeDecoration, SpecializeDecoration, 0, 0)
INST(SequentialIDDecoration, SequentialIDDecoration, 1, 0)
-
+ INST(StaticRequirementDecoration, StaticRequirementDecoration, 0, 0)
+ INST(DispatchFuncDecoration, DispatchFuncDecoration, 1, 0)
INST(TypeConstraintDecoration, TypeConstraintDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a19896287..bf0a5d4cd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -198,6 +198,14 @@ struct IRAnyValueSizeDecoration : IRDecoration
}
};
+struct IRDispatchFuncDecoration : IRDecoration
+{
+ enum { kOp = kIROp_DispatchFuncDecoration };
+ IR_LEAF_ISA(DispatchFuncDecoration)
+
+ IRInst* getFunc() { return getOperand(0); }
+};
+
struct IRSpecializeDecoration : IRDecoration
{
enum { kOp = kIROp_SpecializeDecoration };
@@ -265,6 +273,7 @@ IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration)
/// to it.
IR_SIMPLE_DECORATION(VulkanHitObjectAttributesDecoration)
+
struct IRRequireGLSLVersionDecoration : IRDecoration
{
enum { kOp = kIROp_RequireGLSLVersionDecoration };
@@ -326,6 +335,7 @@ IR_SIMPLE_DECORATION(KeepAliveDecoration)
IR_SIMPLE_DECORATION(RequiresNVAPIDecoration)
IR_SIMPLE_DECORATION(NoInlineDecoration)
IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration)
+IR_SIMPLE_DECORATION(StaticRequirementDecoration)
struct IRNVAPIMagicDecoration : IRDecoration
{
@@ -3913,6 +3923,16 @@ public:
addDecoration(inst, kIROp_AnyValueSizeDecoration, getIntValue(getIntType(), value));
}
+ void addDispatchFuncDecoration(IRInst* inst, IRInst* func)
+ {
+ addDecoration(inst, kIROp_DispatchFuncDecoration, func);
+ }
+
+ void addStaticRequirementDecoration(IRInst* inst)
+ {
+ addDecoration(inst, kIROp_StaticRequirementDecoration);
+ }
+
void addSpecializeDecoration(IRInst* inst)
{
addDecoration(inst, kIROp_SpecializeDecoration);
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index e45b20563..12be27f07 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -28,14 +28,18 @@ namespace Slang
auto genericParent = as<IRGeneric>(genericValue);
SLANG_ASSERT(genericParent);
SLANG_ASSERT(genericParent->getDataType());
- auto func = as<IRFunc>(findGenericReturnVal(genericParent));
+ auto genericRetVal = findGenericReturnVal(genericParent);
+ auto func = as<IRFunc>(genericRetVal);
if (!func)
{
// Nested generic functions are supposed to be flattened before entering
// this pass. The reason we are still seeing them must be that they are
// intrinsic functions. In this case we ignore the function.
- SLANG_ASSERT(findInnerMostGenericReturnVal(genericParent)
- ->findDecoration<IRTargetIntrinsicDecoration>() != nullptr);
+ if (as<IRGeneric>(genericRetVal))
+ {
+ SLANG_ASSERT(findInnerMostGenericReturnVal(genericParent)
+ ->findDecoration<IRTargetIntrinsicDecoration>() != nullptr);
+ }
return genericValue;
}
SLANG_ASSERT(func);
diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp
new file mode 100644
index 000000000..1a1900cd2
--- /dev/null
+++ b/source/slang/slang-ir-lower-witness-lookup.cpp
@@ -0,0 +1,407 @@
+// slang-ir-lower-generic-existential.cpp
+
+#include "slang-ir-lower-witness-lookup.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-generics-lowering-context.h"
+
+namespace Slang
+{
+
+struct WitnessLookupLoweringContext
+{
+ IRModule* module;
+ DiagnosticSink* sink;
+
+ Dictionary<IRStructKey*, IRInst*> witnessDispatchFunctions;
+
+ void init()
+ {
+ // Reconstruct the witness dispatch functions map.
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto key = as<IRStructKey>(inst))
+ {
+ for (auto decor : key->getDecorations())
+ {
+ if (auto witnessDispatchFunc = as<IRDispatchFuncDecoration>(decor))
+ {
+ witnessDispatchFunctions.Add(key, witnessDispatchFunc->getFunc());
+ }
+ }
+ }
+ }
+ }
+
+ bool hasAssocType(IRInst* type)
+ {
+ if (!type)
+ return false;
+ if (type->getOp() == kIROp_AssociatedType)
+ return true;
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ if (hasAssocType(type->getOperand(i)))
+ return true;
+ }
+ return false;
+ }
+
+ IRType* translateType(IRBuilder builder, IRInst* type)
+ {
+ if (!type)
+ return nullptr;
+ if (auto genType = as<IRGeneric>(type))
+ {
+ IRCloneEnv cloneEnv;
+ builder.setInsertBefore(genType);
+ auto newGeneric = as<IRGeneric>(cloneInst(&cloneEnv, &builder, genType));
+ newGeneric->setFullType(builder.getGenericKind());
+ auto retVal = findGenericReturnVal(newGeneric);
+ builder.setInsertBefore(retVal);
+ auto translated = translateType(builder, retVal);
+ retVal->replaceUsesWith(translated);
+ return (IRType*)newGeneric;
+ }
+ else if (auto thisType = as<IRThisType>(type))
+ {
+ return (IRType*)thisType->getConstraintType();
+ }
+ else if (auto assocType = as<IRAssociatedType>(type))
+ {
+ return assocType;
+ }
+
+ if (as<IRBasicType>(type))
+ return (IRType*)type;
+
+ switch (type->getOp())
+ {
+ case kIROp_Param:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ case kIROp_StructType:
+ case kIROp_ClassType:
+ case kIROp_InterfaceType:
+ return (IRType*)type;
+ default:
+ {
+ List<IRInst*> translatedOperands;
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ translatedOperands.add(translateType(builder, type->getOperand(i)));
+ }
+ auto translated = builder.emitIntrinsicInst(
+ builder.getTypeKind(),
+ type->getOp(),
+ (UInt)translatedOperands.getCount(),
+ translatedOperands.getBuffer());
+ return (IRType*)translated;
+ }
+ }
+ }
+
+ IRInst* findOrCreateDispatchFunc(IRLookupWitnessMethod* lookupInst)
+ {
+ IRInst* func = nullptr;
+ auto requirementKey = cast<IRStructKey>(lookupInst->getRequirementKey());
+ if (witnessDispatchFunctions.TryGetValue(requirementKey, func))
+ {
+ return func;
+ }
+
+ auto witnessTableOperand = lookupInst->getWitnessTable();
+ auto witnessTableType = as<IRWitnessTableTypeBase>(witnessTableOperand->getDataType());
+ SLANG_RELEASE_ASSERT(witnessTableType);
+ auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(witnessTableType->getConformanceType()));
+ SLANG_RELEASE_ASSERT(interfaceType);
+ if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ return nullptr;
+ auto requirementType = findInterfaceRequirement(interfaceType, requirementKey);
+ SLANG_RELEASE_ASSERT(requirementType);
+
+ // We only lower non-static function requirement lookups for now.
+ // Our front end will stick a StaticRequirementDecoration on the IRStructKey for static member requirements.
+ if (lookupInst->getRequirementKey()->findDecoration<IRStaticRequirementDecoration>())
+ return nullptr;
+ auto interfaceMethodFuncType = as<IRFuncType>(getResolvedInstForDecorations(requirementType));
+ if (interfaceMethodFuncType)
+ {
+ // Detect cases that we currently does not support and exit.
+
+ // If this is a non static function requirement, we should
+ // make sure the first parameter is the interface type. If not, something has gone wrong.
+ if (interfaceMethodFuncType->getParamCount() == 0)
+ return nullptr;
+ if (!as<IRThisType>(unwrapAttributedType(interfaceMethodFuncType->getParamType(0))))
+ return nullptr;
+
+ // The function has any associated type parameter, we currently can't lower it early in this pass.
+ // We will lower it in the catch all generic lowering pass.
+ for (UInt i = 1; i < interfaceMethodFuncType->getParamCount(); i++)
+ {
+ if (hasAssocType(interfaceMethodFuncType->getParamType(i)))
+ return nullptr;
+ }
+
+ // If return type is a composite type containing an assoc type, we won't lower it now.
+ // Supporting general use of assoc type is possible, but would require more complex logic
+ // in this pass to marshal things to and from existential types.
+ if (interfaceMethodFuncType->getResultType()->getOp() != kIROp_AssociatedType &&
+ hasAssocType(interfaceMethodFuncType->getResultType()))
+ return nullptr;
+ }
+ else
+ {
+ return nullptr;
+ }
+
+
+ IRBuilder builder(module);
+ builder.setInsertBefore(getParentFunc(lookupInst));
+
+ // Create a dispatch func.
+ IRFunc* dispatchFunc = nullptr;
+ IRFuncType* dispatchFuncType = nullptr;
+ IRGeneric* parentGeneric = nullptr;
+
+ // If requirementType is a generic, we need to create a new generic that has the same parameters.
+ if (auto genericRequirement = as<IRGeneric>(requirementType))
+ {
+ IRCloneEnv cloneEnv;
+ parentGeneric = as<IRGeneric>(cloneInst(&cloneEnv, &builder, genericRequirement));
+
+ auto returnInst = as<IRReturn>(parentGeneric->getFirstBlock()->getLastInst());
+ SLANG_RELEASE_ASSERT(returnInst);
+ builder.setInsertBefore(returnInst);
+ auto oldDispatchFuncType = as<IRFuncType>(returnInst->getVal());
+ if (!oldDispatchFuncType)
+ return nullptr;
+
+ dispatchFuncType = as<IRFuncType>(translateType(builder, oldDispatchFuncType));
+
+ SLANG_RELEASE_ASSERT(dispatchFuncType);
+
+ dispatchFunc = builder.createFunc();
+ dispatchFunc->setFullType(dispatchFuncType);
+ builder.emitReturn(dispatchFunc);
+ returnInst->removeAndDeallocate();
+
+ parentGeneric->setFullType(translateType(builder, requirementType));
+ }
+ else
+ {
+ dispatchFuncType = as<IRFuncType>(translateType(builder, requirementType));
+ dispatchFunc = builder.createFunc();
+ dispatchFunc->setFullType(dispatchFuncType);
+ }
+
+ // We need to inline this function if the requirement is differentiable,
+ // so that the autodiff pass doesn't need to handle the dispatch function.
+ if (requirementKey->findDecoration<IRForwardDerivativeDecoration>()||
+ requirementKey->findDecoration<IRBackwardDerivativeDecoration>())
+ {
+ builder.addForceInlineDecoration(dispatchFunc);
+ }
+
+ // Collect generic params.
+ List<IRInst*> genericParams;
+ if (parentGeneric)
+ {
+ for (auto param : parentGeneric->getParams())
+ genericParams.add(param);
+ }
+
+ // Emit the body of the dispatch func.
+ builder.setInsertInto(dispatchFunc);
+ auto firstBlock = builder.emitBlock();
+ auto firstBlockBuilder = builder;
+ // Emit parameters.
+ List<IRInst*> params;
+
+ for (UInt i = 0; i < dispatchFuncType->getParamCount(); i++)
+ {
+ params.add(builder.emitParam(dispatchFuncType->getParamType(i)));
+ }
+ auto witness = builder.emitExtractExistentialWitnessTable(params[0]);
+
+ auto witnessTables = getWitnessTablesFromInterfaceType(module, interfaceType);
+ if (witnessTables.getCount() == 0)
+ {
+ // If there is no witness table, we should emit an error.
+ sink->diagnose(lookupInst, Diagnostics::noTypeConformancesFoundForInterface, interfaceType);
+ return nullptr;
+ }
+ else
+ {
+ List<IRInst*> cases;
+ for (auto witnessTable : witnessTables)
+ {
+ IRBlock* block = builder.emitBlock();
+ auto caseValue = firstBlockBuilder.emitGetSequentialIDInst(witnessTable);
+ cases.add(caseValue);
+ cases.add(block);
+ auto entry = findWitnessTableEntry(witnessTable, requirementKey);
+ SLANG_RELEASE_ASSERT(entry);
+ // If the entry is a generic, we need to specialize it.
+ if (auto genericEntry = as<IRGeneric>(entry))
+ {
+ auto specializedFuncType = builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ entry->getFullType(),
+ (UInt)genericParams.getCount(),
+ genericParams.getBuffer());
+ entry = builder.emitSpecializeInst(
+ (IRType*)specializedFuncType,
+ entry,
+ (UInt)genericParams.getCount(),
+ genericParams.getBuffer());
+ }
+ auto args = params;
+ // Reinterpret the first arg into the concrete type.
+ args[0] = builder.emitReinterpret(witnessTable->getConcreteType(),
+ builder.emitExtractExistentialValue(builder.emitExtractExistentialType(args[0]), args[0]));
+
+ auto calleeFuncType = as<IRFuncType>(getResolvedInstForDecorations(entry)->getFullType());
+ auto callReturnType = calleeFuncType->getResultType();
+ if (callReturnType->getParent() != module->getModuleInst())
+ {
+ // the return type is dependent on generic parameter, use the type from dispatchFuncType instead.
+ callReturnType = dispatchFuncType->getResultType();
+ }
+
+ auto call = builder.emitCallInst(
+ callReturnType,
+ entry,
+ (UInt)args.getCount(),
+ args.getBuffer());
+ // If result type is an associated type, we need to pack it into an anyValue.
+ if (as<IRAssociatedType>(dispatchFuncType->getResultType()))
+ {
+ call = builder.emitPackAnyValue(dispatchFuncType->getResultType(), call);
+ }
+ builder.emitReturn(call);
+ }
+ builder.setInsertInto(firstBlock);
+ if (witnessTables.getCount() == 1)
+ {
+ builder.emitBranch((IRBlock*)cases[1]);
+ }
+ else
+ {
+ auto witnessId = firstBlockBuilder.emitGetSequentialIDInst(witness);
+ auto breakLabel = builder.emitBlock();
+ builder.emitUnreachable();
+ firstBlockBuilder.emitSwitch(
+ witnessId,
+ breakLabel,
+ (IRBlock*)cases.getLast(),
+ (UInt)(cases.getCount() - 2),
+ cases.getBuffer());
+ }
+ }
+
+ // Stick a decoration on the requirement key so we can find the dispatch func later.
+ IRInst* resultValue = parentGeneric ? (IRInst*)parentGeneric : dispatchFunc;
+ builder.addDispatchFuncDecoration(requirementKey, resultValue);
+
+ // Register the dispatch func to witnessDispatchFunctions dictionary.
+ witnessDispatchFunctions[requirementKey] = resultValue;
+
+ return resultValue;
+ }
+
+ void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* existentialObject)
+ {
+ SLANG_RELEASE_ASSERT(call->getArgCount() != 0);
+ call->setOperand(0, dispatchFunc);
+ call->setOperand(1, existentialObject);
+ }
+
+ bool processWitnessLookup(IRLookupWitnessMethod* lookupInst)
+ {
+ auto witnessTableOperand = lookupInst->getWitnessTable();
+ auto extractInst = as<IRExtractExistentialWitnessTable>(witnessTableOperand);
+ if (!extractInst)
+ return false;
+ auto dispatchFunc = findOrCreateDispatchFunc(lookupInst);
+ if (!dispatchFunc)
+ return false;
+ bool changed = false;
+ auto existentialObject = extractInst->getOperand(0);
+
+ IRBuilder builder(lookupInst);
+ builder.setInsertBefore(lookupInst);
+ traverseUses(lookupInst, [&](IRUse* use)
+ {
+ if (auto specialize = as<IRSpecialize>(use->getUser()))
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < specialize->getArgCount(); i++)
+ args.add(specialize->getArg(i));
+ auto specializedType = builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ dispatchFunc->getFullType(),
+ (UInt)args.getCount(),
+ args.getBuffer());
+ auto newSpecialize = builder.emitSpecializeInst(
+ (IRType*)specializedType,
+ dispatchFunc,
+ (UInt)args.getCount(),
+ args.getBuffer());
+ traverseUses(specialize, [&](IRUse* specializeUse)
+ {
+ if (auto call = as<IRCall>(specializeUse->getUser()))
+ {
+ changed = true;
+ rewriteCallSite(call, newSpecialize, existentialObject);
+ }
+ });
+ }
+ else if (auto call = as<IRCall>(use->getUser()))
+ {
+ changed = true;
+ rewriteCallSite(call, dispatchFunc, existentialObject);
+ }
+ });
+ return changed;
+ }
+
+ bool processFunc(IRFunc* func)
+ {
+ bool changed = false;
+ for (auto bb : func->getBlocks())
+ {
+ for (auto inst : bb->getChildren())
+ {
+ if (auto witnessLookupInst = as<IRLookupWitnessMethod>(inst))
+ {
+ changed |= processWitnessLookup(witnessLookupInst);
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink)
+{
+ bool changed = false;
+ WitnessLookupLoweringContext context;
+ context.module = module;
+ context.sink = sink;
+ context.init();
+
+ for (auto inst : module->getGlobalInsts())
+ {
+ // Process all fully specialized functions and look for
+ // witness lookup instructions. If we see a lookup for a non-static function,
+ // create a dispatch function and replace the lookup with a call to the dispatch function.
+ if (auto func = as<IRFunc>(inst))
+ changed |= context.processFunc(func);
+ }
+ return changed;
+}
+}
diff --git a/source/slang/slang-ir-lower-witness-lookup.h b/source/slang/slang-ir-lower-witness-lookup.h
new file mode 100644
index 000000000..4ae447210
--- /dev/null
+++ b/source/slang/slang-ir-lower-witness-lookup.h
@@ -0,0 +1,16 @@
+// slang-ir-lower-witness-lookup.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ class DiagnosticSink;
+
+ /// Lower calls to a witness lookup into a call to a dispatch function.
+ /// For example, if we see call(witnessLookup(wt, key)), we will create a
+ /// dispatch function that calls into different implementations based on witness table
+ /// ID. The dispatch function will be called instead of witnessLookup.
+ bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink);
+
+}
+
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index ab3f0ceab..2244b480a 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -827,6 +827,50 @@ struct PeepholeContext : InstPassBase
}
}
break;
+ case kIROp_swizzle:
+ {
+ // If we see a swizzle(makeVector) then we can replace it with the values from makeVector.
+ auto makeVector = inst->getOperand(0);
+ if (makeVector->getOp() != kIROp_MakeVector)
+ break;
+ auto swizzle = as<IRSwizzle>(inst);
+ List<IRInst*> vals;
+ auto vectorType = as<IRVectorType>(makeVector->getDataType());
+ auto vectorSize = as<IRIntLit>(vectorType->getElementCount());
+ if (!vectorSize)
+ break;
+ if (makeVector->getOperandCount() != (UInt)vectorSize->getValue())
+ break;
+ for (UInt i = 0; i < swizzle->getElementCount(); i++)
+ {
+ auto index = swizzle->getElementIndex(i);
+ auto intLitIndex = as<IRIntLit>(index);
+ if (!intLitIndex)
+ return;
+ if (intLitIndex->getValue() < (Int)makeVector->getOperandCount())
+ vals.add(makeVector->getOperand((UInt)intLitIndex->getValue()));
+ else
+ return;
+ }
+ if (vals.getCount() == 1)
+ {
+ inst->replaceUsesWith(vals[0]);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ else
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ auto newMakeVector = builder.emitMakeVector(
+ swizzle->getDataType(), (UInt)vals.getCount(), vals.getBuffer());
+ inst->replaceUsesWith(newMakeVector);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ break;
+ }
+
default:
break;
}
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index 4d9b15bbd..074aedb06 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -112,7 +113,7 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
builder->setInsertInto(defaultBlock);
}
- auto callee = sharedContext->findWitnessTableEntry(witnessTable, requirementKey);
+ auto callee = findWitnessTableEntry(witnessTable, requirementKey);
SLANG_ASSERT(callee);
auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
if (callInst->getDataType()->getOp() == kIROp_VoidType)
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index e68ac5b73..1a1186cda 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -75,7 +75,7 @@ struct AssociatedTypeLookupSpecializationContext
builder.setInsertInto(defaultBlock);
}
- auto resultWitnessTable = sharedContext->findWitnessTableEntry(witnessTable, key);
+ auto resultWitnessTable = findWitnessTableEntry(witnessTable, key);
auto resultWitnessTableIDDecoration =
resultWitnessTable->findDecoration<IRSequentialIDDecoration>();
SLANG_ASSERT(resultWitnessTableIDDecoration);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index d2e042363..eb3677653 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-lower-witness-lookup.h"
namespace Slang
{
@@ -43,6 +44,7 @@ struct SpecializationContext
// For convenience, we will keep a pointer to the module
// we are specializing.
IRModule* module;
+ DiagnosticSink* sink;
bool changed = false;
@@ -932,7 +934,11 @@ struct SpecializationContext
}
else
{
- break;
+ // If we run out of specialization opportunities, consider
+ // lower lookupWitnessMethod insts into dynamic dispatch calls.
+ iterChanged = lowerWitnessLookup(module, sink);
+ if (!iterChanged || sink->getErrorCount())
+ break;
}
}
@@ -1323,7 +1329,6 @@ struct SpecializationContext
IRInst* curInst = localWorkList.getLast();
localWorkList.removeLast();
- processedInsts.Remove(curInst);
switch (curInst->getOp())
{
@@ -2329,10 +2334,12 @@ struct SpecializationContext
};
bool specializeModule(
- IRModule* module)
+ IRModule* module,
+ DiagnosticSink* sink)
{
SpecializationContext context;
context.module = module;
+ context.sink = sink;
context.processModule();
return context.changed;
}
@@ -2349,6 +2356,7 @@ void finalizeSpecialization(IRModule* module)
case kIROp_ExistentialFuncSpecializationDictionary:
case kIROp_ExistentialTypeSpecializationDictionary:
case kIROp_GenericSpecializationDictionary:
+ case kIROp_DispatchFuncDecoration:
decor->removeAndDeallocate();
break;
default:
diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h
index 20d65cb67..1feff3e4b 100644
--- a/source/slang/slang-ir-specialize.h
+++ b/source/slang/slang-ir-specialize.h
@@ -4,10 +4,12 @@
namespace Slang
{
struct IRModule;
+class DiagnosticSink;
/// Specialize generic and interface-based code to use concrete types.
bool specializeModule(
- IRModule* module);
+ IRModule* module,
+ DiagnosticSink* sink);
void finalizeSpecialization(IRModule* module);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index c5cebb8b5..83f6735bd 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -698,6 +698,29 @@ bool isPureFunctionalCall(IRCall* call)
return false;
}
+IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
+{
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
+ {
+ if (req->getRequirementKey() == key)
+ return req->getRequirementVal();
+ }
+ }
+ return nullptr;
+}
+
+IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
+{
+ for (auto entry : table->getEntries())
+ {
+ if (entry->getRequirementKey() == key)
+ return entry->getSatisfyingVal();
+ }
+ return nullptr;
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index f8e53c38f..ef7ff47bb 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -194,6 +194,10 @@ void sortBlocksInFunc(IRGlobalValueWithCode* func);
// Remove all linkage decorations from func.
void removeLinkageDecorations(IRGlobalValueWithCode* func);
+
+IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
+
+IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key);
}
#endif
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f0c30dd3c..ada2e043e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7145,6 +7145,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
entry->setRequirementVal(requirementVal);
break;
}
+ if (requirementDecl->findModifier<HLSLStaticModifier>())
+ {
+ getBuilder()->addStaticRequirementDecoration(requirementKey);
+ }
}
}
irInterface->setOperand(entryIndex, entry);
@@ -7807,6 +7811,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// to the newly construct generic value.
typeBuilder.setInsertBefore(parentGeneric);
auto typeGeneric = typeBuilder.emitGeneric();
+ typeGeneric->setFullType(typeBuilder.getGenericKind());
typeBuilder.setInsertInto(typeGeneric);
typeBuilder.emitBlock();
diff --git a/tests/diagnostics/no-type-conformance.slang.expected b/tests/diagnostics/no-type-conformance.slang.expected
index bc38fa7f1..5f5eda6af 100644
--- a/tests/diagnostics/no-type-conformance.slang.expected
+++ b/tests/diagnostics/no-type-conformance.slang.expected
@@ -1,8 +1,8 @@
result code = -1
standard error = {
-tests/diagnostics/no-type-conformance.slang(4): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage.
-interface IFoo
- ^~~~
+tests/diagnostics/no-type-conformance.slang(12): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage.
+ obj.get();
+ ^
}
standard output = {
}
diff --git a/tests/ir/dynamic-generic-method-specialize.slang b/tests/ir/dynamic-generic-method-specialize.slang
new file mode 100644
index 000000000..92ce8158e
--- /dev/null
+++ b/tests/ir/dynamic-generic-method-specialize.slang
@@ -0,0 +1,65 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -profile sm_5_0 -output-using-type
+
+// Test that we can specialize a generic method called through a dynamic interface.
+
+interface IValue
+{
+ float getVal();
+}
+
+struct SimpleVal : IValue
+{
+ float val;
+ float getVal() { return val; }
+}
+
+[anyValueSize(16)]
+interface IInterface
+{
+ associatedtype V : IValue;
+ V run<let N : int>(float arr[N]);
+}
+
+struct Add : IInterface
+{
+ float base;
+ typealias V = SimpleVal;
+ V run<let N : int>(float arr[N])
+ {
+ float sum = base;
+ for (int i = 0; i < N; i++)
+ sum += arr[i];
+ V rs;
+ rs.val = sum;
+ return rs;
+ }
+}
+
+struct Mul : IInterface
+{
+ float base;
+ typealias V = SimpleVal;
+ V run<let N : int>(float arr[N])
+ {
+ float sum = base;
+ for (int i = 0; i < N; i++)
+ sum *= arr[i];
+ V rs;
+ rs.val = sum;
+ return rs;
+ }
+}
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer
+RWStructuredBuffer<float> gOutputBuffer;
+
+//TEST_INPUT:type_conformance Add:IInterface=1
+//TEST_INPUT:type_conformance Mul:IInterface=2
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(1, 1.0); // Add.
+ float arr[3] = { 2, 3, 4 };
+ gOutputBuffer[0] = obj.run(arr).getVal();
+} \ No newline at end of file
diff --git a/tests/ir/dynamic-generic-method-specialize.slang.expected.txt b/tests/ir/dynamic-generic-method-specialize.slang.expected.txt
new file mode 100644
index 000000000..0bc25648a
--- /dev/null
+++ b/tests/ir/dynamic-generic-method-specialize.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+10.0 \ No newline at end of file
diff --git a/tests/language-server/robustness-7.slang b/tests/language-server/robustness-7.slang
new file mode 100644
index 000000000..ebc34e078
--- /dev/null
+++ b/tests/language-server/robustness-7.slang
@@ -0,0 +1,63 @@
+//TEST:LANG_SERVER:
+//HOVER:6,15
+
+// Test that we can specialize a generic method called through a dynamic interface.
+
+interface IValue
+{
+ float getVal();
+}
+
+struct SimpleVal : IValue
+{
+ float val;
+ float getVal() { return val; }
+}
+
+[anyValueSize(16)]
+interface IInterface
+{
+ associatedtype V : IValue;
+ V run<let N : int>(float arr[N]);
+}
+
+struct Add : IInterface
+{
+ float base;
+ typealias V
+ float run<let N : int>(float arr[N])
+ {
+ float sum = base;
+ for (int i = 0; i < N; i++)
+ sum += arr[i];
+ return sum;
+ }
+}
+
+struct Mul : IInterface
+{
+ float base;
+
+ float run<let N : int>(float arr[N])
+ {
+ float sum = base;
+ for (int i = 0; i < N; i++)
+ sum *= arr[i];
+ return sum;
+ }
+}
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer
+RWStructuredBuffer<float> gOutputBuffer;
+
+//TEST_INPUT:type_conformance Add:IInterface=1
+//TEST_INPUT:type_conformance Mul:IInterface=2
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(1, 1.0); // Add.
+ float arr[3] = { 2, 3, 4 };
+ gOutputBuffer[0] = obj.run(arr);
+
+} \ No newline at end of file
diff --git a/tests/language-server/robustness-7.slang.expected.txt b/tests/language-server/robustness-7.slang.expected.txt
new file mode 100644
index 000000000..d5f8ed9f8
--- /dev/null
+++ b/tests/language-server/robustness-7.slang.expected.txt
@@ -0,0 +1,12 @@
+--------
+range: 5,10 - 5,16
+content:
+```
+interface IValue
+```
+
+Test that we can specialize a generic method called through a dynamic interface.
+
+{REDACTED}.slang(6)
+
+