summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-lower-to-ir.cpp112
-rw-r--r--tests/language-feature/constants/static-const-in-generic-interface.slang33
-rw-r--r--tests/language-feature/interfaces/generic-interface-conformance.slang31
-rw-r--r--tools/slang-test/slang-test-main.cpp2
4 files changed, 143 insertions, 35 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 74aa0a0ee..583dcaacc 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8230,6 +8230,40 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
context->irBuilder->addDecoration(originalKey, op, associatedKey);
}
+ // Given `value` defined as an independent generic of `outerGeneric`, emit IR that specializes it using
+ // the generic params defined in `outerGeneric`.
+ // For example:
+ // ```
+ // interface IFoo<T> { void f(); }
+ // ```
+ // We will lower `IFoo<T>::f` into `%f = IRGeneric(T) { return IRFunc(...) }`
+ // When we lower the interface type `IFoo`, it will become:
+ // ```
+ // %IFoo = IRGeneric(T1) { return IRInterfaceType(???); )
+ // ```
+ // We want the `???` to be `specialize(%f, T1)`.
+ // To do so, we will call `specializeWithOuterGeneric` with `value` = `%f`, and `outerGeneric` = %IFoo.
+ //
+ IRInst* specializeWithOuterGeneric(IRBuilder* irBuilder, IRInst* value, IRGeneric* outerGeneric)
+ {
+ if (!as<IRGeneric>(value))
+ return value;
+ if (!outerGeneric)
+ return value;
+
+ // If `outerGeneric` has a generic parent, we want to recursively specialize value
+ // using the parent generic first.
+ auto parentGeneric = getOuterGeneric(outerGeneric);
+ if (parentGeneric)
+ value = specializeWithOuterGeneric(irBuilder, value, parentGeneric);
+
+ // Now we can specialize `value` using the params defined in `outerGeneric`.
+ List<IRInst*> args;
+ for (auto param : outerGeneric->getParams())
+ args.add(param);
+ return irBuilder->emitSpecializeInst(irBuilder->getGenericKind(), value, args);
+ }
+
LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl)
{
// The members of an interface will turn into the keys that will
@@ -8306,54 +8340,55 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
UInt entryIndex = 0;
- auto addEntry = [&](IRStructKey* requirementKey, Decl* requirementDecl)
+ auto addEntry = [&](IRStructKey* requirementKey, DeclRef<Decl> requirementDeclRef)
{
auto entry = subBuilder->createInterfaceRequirementEntry(
requirementKey,
nullptr);
- if (auto inheritance = as<InheritanceDecl>(requirementDecl))
+ if (auto inheritance = requirementDeclRef.as<InheritanceDecl>())
{
- auto irBaseType = lowerType(context, inheritance->base.type);
+ auto irBaseType = lowerType(subContext, getSup(subContext->astBuilder, inheritance));
auto irWitnessTableType = subBuilder->getWitnessTableType(irBaseType);
entry->setRequirementVal(irWitnessTableType);
}
else
{
- IRInst* requirementVal = ensureDecl(subContext, requirementDecl).val;
- if (requirementVal)
+ auto requirementVal = ensureDecl(subContext, requirementDeclRef.getDecl()).val;
+
+ switch (requirementVal->getOp())
{
- switch (requirementVal->getOp())
- {
- case kIROp_Func:
- case kIROp_Generic:
- {
- // Remove lowered `IRFunc`s since we only care about
- // function types.
- auto reqType = requirementVal->getFullType();
- entry->setRequirementVal(reqType);
- break;
- }
- default:
- entry->setRequirementVal(requirementVal);
- break;
- }
- if (requirementDecl->findModifier<HLSLStaticModifier>())
- {
- getBuilder()->addStaticRequirementDecoration(requirementKey);
- }
+ default:
+ // For the majority of requirements, we only care about its type in an
+ // interface definition, so we store only the type from the lowered IR
+ // in the interface entry.
+ // We need to make sure the type is specialized with the outer generic
+ // parameters in case the interface itself is inside a generic.
+ //
+ requirementVal = specializeWithOuterGeneric(context->irBuilder, requirementVal->getFullType(), outerGeneric);
+ entry->setRequirementVal(requirementVal);
+ break;
+
+ case kIROp_AssociatedType:
+ // For associated types, we will store it directly inside the interface type.
+ entry->setRequirementVal(requirementVal);
+ break;
+ }
+ if (requirementDeclRef.getDecl()->findModifier<HLSLStaticModifier>())
+ {
+ getBuilder()->addStaticRequirementDecoration(requirementKey);
}
}
irInterface->setOperand(entryIndex, entry);
entryIndex++;
// Add addtional requirements for type constraints placed
// on an associated types.
- if (auto associatedTypeDecl = as<AssocTypeDecl>(requirementDecl))
+ if (auto associatedTypeDeclRef = requirementDeclRef.as<AssocTypeDecl>())
{
- for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>())
+ for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(subContext->astBuilder, associatedTypeDeclRef))
{
- auto constraintKey = getInterfaceRequirementKey(constraintDecl);
+ auto constraintKey = getInterfaceRequirementKey(constraintDeclRef.getDecl());
auto constraintInterfaceType =
- lowerType(context, constraintDecl->getSup().type);
+ lowerType(context, getSup(subContext->astBuilder, constraintDeclRef));
auto witnessTableType =
getBuilder()->getWitnessTableType(constraintInterfaceType);
@@ -8362,16 +8397,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irInterface->setOperand(entryIndex, constraintEntry);
entryIndex++;
- context->setValue(constraintDecl, LoweredValInfo::simple(constraintEntry));
+ context->setValue(constraintDeclRef.getDecl(), LoweredValInfo::simple(constraintEntry));
}
}
else
{
CallableDecl* callableDecl = nullptr;
- if (auto genDecl = as<GenericDecl>(requirementDecl))
+ if (auto genDecl = as<GenericDecl>(requirementDeclRef.getDecl()))
callableDecl = as<CallableDecl>(genDecl->inner);
else
- callableDecl = as<CallableDecl>(requirementDecl);
+ callableDecl = as<CallableDecl>(requirementDeclRef.getDecl());
if (callableDecl)
{
// Differentiable functions has additional requirements for the derivatives.
@@ -8384,7 +8419,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Add lowered requirement entry to current decl mapping to prevent
// the function requirements from being lowered again when we get to
// `ensureAllDeclsRec`.
- context->setValue(requirementDecl, LoweredValInfo::simple(entry));
+ context->setValue(requirementDeclRef.getDecl(), LoweredValInfo::simple(entry));
}
};
for (auto requirementDecl : decl->members)
@@ -8400,7 +8435,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
auto accessorKey = getInterfaceRequirementKey(accessorDecl);
if (accessorKey)
- addEntry(accessorKey, accessorDecl);
+ {
+ auto accessorDeclRef = createDefaultSpecializedDeclRef(subContext, nullptr, accessorDecl);
+ addEntry(accessorKey, accessorDeclRef);
+ }
}
}
}
@@ -8408,7 +8446,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else
{
- addEntry(requirementKey, requirementDecl);
+ if (auto genericDecl = as<GenericDecl>(requirementDecl))
+ {
+ // We need to form a declref into the inner decls in case of a generic requirement.
+ requirementDecl = getInner(genericDecl);
+ }
+ auto requirementDeclRef = createDefaultSpecializedDeclRef(subContext, nullptr, requirementDecl);
+ addEntry(requirementKey, requirementDeclRef);
}
}
diff --git a/tests/language-feature/constants/static-const-in-generic-interface.slang b/tests/language-feature/constants/static-const-in-generic-interface.slang
new file mode 100644
index 000000000..87d8e3be8
--- /dev/null
+++ b/tests/language-feature/constants/static-const-in-generic-interface.slang
@@ -0,0 +1,33 @@
+// static-const-in-generic-interface.slang
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+
+// Test that `static const` variable declarations inside of
+// a generic `interface` type correctly translate to interface requirements.
+
+interface ITest<T:__BuiltinIntegerType>
+{
+ static const T kUserDefinedValue;
+}
+
+struct Impl : ITest<int>
+{
+ static const int kUserDefinedValue = 4;
+}
+
+struct EnsureCompileTimeEval<T : __BuiltinIntegerType>
+{
+ static T getValue<U : ITest<T>>() { return U.kUserDefinedValue; }
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ static const int result = EnsureCompileTimeEval<int>.getValue<Impl>();
+ int outVal = result;
+ // CHECK: 4
+ outputBuffer[0] = outVal;
+}
diff --git a/tests/language-feature/interfaces/generic-interface-conformance.slang b/tests/language-feature/interfaces/generic-interface-conformance.slang
new file mode 100644
index 000000000..9e0510125
--- /dev/null
+++ b/tests/language-feature/interfaces/generic-interface-conformance.slang
@@ -0,0 +1,31 @@
+// Test that we allow type conformances whose base interface is generic.
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-dx11 -compute -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -output-using-type
+
+public interface ITestInterface<Real : IFloat> {
+ Real sample();
+}
+
+struct TestInterfaceImpl<Real : IFloat> : ITestInterface<Real> {
+ Real sample() {
+ return x;
+ }
+ Real x;
+}
+
+//TEST_INPUT: set data = new StructuredBuffer<ITestInterface<float> >[new TestInterfaceImpl<float>{1.0}];
+StructuredBuffer<ITestInterface<float>> data;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
+RWStructuredBuffer<int> outputBuffer;
+
+//TEST_INPUT: type_conformance TestInterfaceImpl<float>:ITestInterface<float> = 3
+
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ let obj = data[0];
+ // CHECK: 1
+ outputBuffer[0] = int(obj.sample());
+} \ No newline at end of file
diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp
index cd3ca3d54..06eae7047 100644
--- a/tools/slang-test/slang-test-main.cpp
+++ b/tools/slang-test/slang-test-main.cpp
@@ -4617,8 +4617,8 @@ SlangResult innerMain(int argc, char** argv)
int main(int argc, char** argv)
{
const SlangResult res = innerMain(argc, argv);
-
slang::shutdown();
+ Slang::RttiInfo::deallocateAll();
#ifdef _MSC_VER
_CrtDumpMemoryLeaks();