summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang.h33
-rw-r--r--source/slang/slang-compiler.cpp86
-rwxr-xr-xsource/slang/slang-compiler.h126
-rw-r--r--source/slang/slang-ir-link.cpp4
-rw-r--r--source/slang/slang-ir-specialize.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp64
-rw-r--r--source/slang/slang-lower-to-ir.h11
-rw-r--r--source/slang/slang-parameter-binding.cpp18
-rw-r--r--source/slang/slang.cpp47
-rw-r--r--tests/compute/dynamic-dispatch-16.slang10
-rw-r--r--tools/render-test/shader-input-layout.cpp15
-rw-r--r--tools/render-test/shader-input-layout.h10
-rw-r--r--tools/render-test/slang-support.cpp39
-rw-r--r--tools/render-test/slang-support.h9
14 files changed, 470 insertions, 7 deletions
diff --git a/slang.h b/slang.h
index 138076110..5433b4f37 100644
--- a/slang.h
+++ b/slang.h
@@ -3050,6 +3050,7 @@ namespace slang
typedef ISlangBlob IBlob;
struct IComponentType;
+ struct ITypeConformance;
struct IGlobalSession;
struct IModule;
struct ISession;
@@ -4023,6 +4024,32 @@ namespace slang
*/
virtual SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
SlangCompileRequest** outCompileRequest) = 0;
+
+
+ /** Creates a `IComponentType` that represents a type's conformance to an interface.
+ The retrieved `ITypeConformance` objects can be included in a composite `IComponentType`
+ to explicitly specify which implementation types should be included in the final compiled
+ code. For example, if an module defines `IMaterial` interface and `AMaterial`,
+ `BMaterial`, `CMaterial` types that implements the interface, the user can exclude
+ `CMaterial` implementation from the resulting shader code by explcitly adding
+ `AMaterial:IMaterial` and `BMaterial:IMaterial` conformances to a composite
+ `IComponentType` and get entry point code from it. The resulting code will not have
+ anything related to `CMaterial` in the dynamic dispatch logic. If the user does not
+ explicitly include any `TypeConformances` to an interface type, all implementations to
+ that interface will be included by default. By linking a `ITypeConformance`, the user is
+ also given the opportunity to specify the dispatch ID of the implementation type. If
+ `conformanceIdOverride` is -1, there will be no override behavior and Slang will
+ automatically assign IDs to implementation types. The automatically assigned IDs can be
+ queried via `ISession::getTypeConformanceWitnessSequentialID`.
+
+ Returns SLANG_OK if succeeds, or SLANG_FAIL if `type` does not conform to `interfaceType`.
+ */
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ ITypeConformance** outConformance,
+ SlangInt conformanceIdOverride,
+ ISlangBlob** outDiagnostics) = 0;
};
#define SLANG_UUID_ISession ISession::getTypeGuid()
@@ -4204,6 +4231,12 @@ namespace slang
#define SLANG_UUID_IEntryPoint IEntryPoint::getTypeGuid()
+ struct ITypeConformance : public IComponentType
+ {
+ SLANG_COM_INTERFACE(0x73eb3147, 0xe544, 0x41b5, { 0xb8, 0xf0, 0xa2, 0x44, 0xdf, 0x21, 0x94, 0xb })
+ };
+ #define SLANG_UUID_ITypeConformance ITypeConformance::getTypeGuid()
+
/** A module is the granularity of shader code compilation and loading.
In most cases a module corresponds to a single compile "translation unit."
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 997f3fd51..5d0ae2f8d 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -296,6 +296,92 @@ namespace Slang
return empty;
}
+ TypeConformance::TypeConformance(
+ Linkage* linkage,
+ SubtypeWitness* witness,
+ Int confomrmanceIdOverride,
+ DiagnosticSink* sink)
+ : ComponentType(linkage)
+ , m_subtypeWitness(witness)
+ , m_conformanceIdOverride(confomrmanceIdOverride)
+ {
+ addDepedencyFromWitness(witness);
+ m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink);
+ }
+
+ void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness)
+ {
+ if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness))
+ {
+ auto declModule = getModule(declaredWitness->declRef.getDecl());
+ m_moduleDependency.addDependency(declModule);
+ m_pathDependency.addDependency(declModule);
+ if (m_requirementSet.Add(declModule))
+ {
+ m_requirements.add(declModule);
+ }
+ // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions.
+ }
+ else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness))
+ {
+ addDepedencyFromWitness(transitiveWitness->midToSup);
+ addDepedencyFromWitness(transitiveWitness->subToMid);
+ }
+ else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness))
+ {
+ auto left = as<SubtypeWitness>(conjunctionWitness->leftWitness);
+ if (left)
+ addDepedencyFromWitness(left);
+ auto right = as<SubtypeWitness>(conjunctionWitness->rightWitness);
+ if (right)
+ addDepedencyFromWitness(right);
+ }
+ }
+
+ ISlangUnknown* TypeConformance::getInterface(const Guid& guid)
+ {
+ if (guid == slang::ITypeConformance::getTypeGuid())
+ return static_cast<slang::ITypeConformance*>(this);
+
+ return Super::getInterface(guid);
+ }
+
+ List<Module*> const& TypeConformance::getModuleDependencies()
+ {
+ return m_moduleDependency.getModuleList();
+ }
+
+ List<String> const& TypeConformance::getFilePathDependencies()
+ {
+ return m_pathDependency.getFilePathList();
+ }
+
+ Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); }
+
+ RefPtr<ComponentType> TypeConformance::getRequirement(Index index)
+ {
+ return m_requirements[index];
+ }
+
+ void TypeConformance::acceptVisitor(
+ ComponentTypeVisitor* visitor,
+ ComponentType::SpecializationInfo* specializationInfo)
+ {
+ SLANG_UNUSED(specializationInfo);
+ visitor->visitTypeConformance(this);
+ }
+
+ RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
+ {
+ SLANG_UNUSED(args);
+ SLANG_UNUSED(argCount);
+ SLANG_UNUSED(sink);
+ return nullptr;
+ }
+
//
Profile Profile::lookUp(UnownedStringSlice const& name)
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 1724794cf..727d21c3a 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -850,6 +850,125 @@ namespace Slang
Profile m_profile;
};
+ class TypeConformance
+ : public ComponentType
+ , public slang::ITypeConformance
+ {
+ typedef ComponentType Super;
+
+ public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+
+ ISlangUnknown* getInterface(const Guid& guid);
+
+ TypeConformance(
+ Linkage* linkage,
+ SubtypeWitness* witness,
+ Int confomrmanceIdOverride,
+ DiagnosticSink* sink);
+
+ // Forward `IComponentType` methods
+
+ SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE
+ {
+ return Super::getSession();
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL
+ getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getLayout(targetIndex, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::specialize(
+ specializationArgs,
+ specializationArgCount,
+ outSpecializedComponentType,
+ outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::link(outLinkedComponentType, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointHostCallable(
+ entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ }
+
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE;
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE;
+
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
+
+ /// Get the existential type parameter at `index`.
+ SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE
+ {
+ static SpecializationParam emptyParam;
+ return emptyParam;
+ }
+
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
+ Index getEntryPointCount() SLANG_OVERRIDE { return 0; };
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(index);
+ return nullptr;
+ }
+ String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; }
+
+ Index getShaderParamCount() SLANG_OVERRIDE { return 0; }
+ ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(index);
+ return ShaderParamInfo();
+ }
+
+ SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; }
+ IRModule* getIRModule() { return m_irModule.Ptr(); }
+ protected:
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+ SLANG_OVERRIDE;
+
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE;
+ private:
+ SubtypeWitness* m_subtypeWitness;
+ ModuleDependencyList m_moduleDependency;
+ FilePathDependencyList m_pathDependency;
+ List<RefPtr<Module>> m_requirements;
+ HashSet<Module*> m_requirementSet;
+ RefPtr<IRModule> m_irModule;
+ Int m_conformanceIdOverride;
+ void addDepedencyFromWitness(SubtypeWitness* witness);
+ };
+
enum class PassThroughMode : SlangPassThroughIntegral
{
None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler
@@ -1319,6 +1438,12 @@ namespace Slang
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
uint32_t* outId) override;
+ SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ slang::ITypeConformance** outConformance,
+ SlangInt conformanceIdOverride,
+ ISlangBlob** outDiagnostics) override;
SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
SlangCompileRequest** outCompileRequest) override;
@@ -1756,6 +1881,7 @@ namespace Slang
virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0;
virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0;
virtual void visitSpecialized(SpecializedComponentType* specialized) = 0;
+ virtual void visitTypeConformance(TypeConformance* conformance) = 0;
protected:
// These helpers can be used to recurse into the logical children of a
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 480fe504c..36985efd6 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -274,7 +274,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
default:
{
- // In the deafult case, assume that we have some sort of "hoistable"
+ // In the default case, assume that we have some sort of "hoistable"
// instruction that requires us to create a clone of it.
UInt argCount = originalValue->getOperandCount();
@@ -439,6 +439,8 @@ static void cloneExtraDecorations(
case kIROp_BindExistentialSlotsDecoration:
case kIROp_LayoutDecoration:
+ case kIROp_PublicDecoration:
+ case kIROp_SequentialIDDecoration:
if(!clonedInst->findDecorationImpl(decoration->getOp()))
{
cloneInst(context, builder, decoration);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index e9300ae06..9522dbeca 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -2158,6 +2158,11 @@ IRInst* specializeGenericImpl(
if( auto returnValInst = as<IRReturnVal>(ii) )
{
auto specializedVal = findCloneForOperand(&env, returnValInst->getVal());
+
+ // Clone decorations on the orignal `specialize` inst over to the newly specialized
+ // value.
+ cloneInstDecorationsAndChildren(
+ &env, &sharedBuilderStorage, specializeInst, specializedVal);
return specializedVal;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index c7e32072e..8a40dbde9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8255,6 +8255,11 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor
{
visitChildren(specialized);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
RefPtr<IRModule> generateIRForSpecializedComponentType(
@@ -8265,6 +8270,65 @@ RefPtr<IRModule> generateIRForSpecializedComponentType(
return context.process(componentType, sink);
}
+ /// Context for generating IR code to represent a `TypeConformance`
+struct TypeConformanceIRGenContext
+{
+ DiagnosticSink* sink;
+ Linkage* linkage;
+ Session* session;
+ IRGenContext* context;
+ IRBuilder* builder;
+
+ RefPtr<IRModule> process(
+ TypeConformance* typeConformance,
+ Int conformanceIdOverride,
+ DiagnosticSink* inSink)
+ {
+ sink = inSink;
+
+ linkage = typeConformance->getLinkage();
+ session = linkage->getSessionImpl();
+
+ SharedIRGenContext sharedContextStorage(session, sink, linkage->m_obfuscateCode);
+ SharedIRGenContext* sharedContext = &sharedContextStorage;
+
+ IRGenContext contextStorage(sharedContext, linkage->getASTBuilder());
+ context = &contextStorage;
+
+ SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = session;
+
+ IRBuilder builderStorage;
+ builder = &builderStorage;
+ builder->sharedBuilder = sharedBuilder;
+
+ RefPtr<IRModule> module = builder->createModule();
+ sharedBuilder->module = module;
+
+ builder->setInsertInto(module->getModuleInst());
+
+ context->irBuilder = builder;
+
+ auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness());
+ builder->addPublicDecoration(witness);
+ if (conformanceIdOverride != -1)
+ {
+ builder->addSequentialIDDecoration(witness, conformanceIdOverride);
+ }
+ return module;
+ }
+};
+
+RefPtr<IRModule> generateIRForTypeConformance(
+ TypeConformance* typeConformance,
+ Int conformanceIdOverride,
+ DiagnosticSink* sink)
+{
+ TypeConformanceIRGenContext context;
+ return context.process(typeConformance, conformanceIdOverride, sink);
+}
RefPtr<IRModule> TargetProgram::getOrCreateIRModuleForLayout(DiagnosticSink* sink)
{
diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h
index dbf2e550a..ce7f9eaf0 100644
--- a/source/slang/slang-lower-to-ir.h
+++ b/source/slang/slang-lower-to-ir.h
@@ -43,5 +43,14 @@ namespace Slang
RefPtr<IRModule> generateIRForSpecializedComponentType(
SpecializedComponentType* componentType,
DiagnosticSink* sink);
-}
+
+ /// Generate an IR module to represent a user specified `TypeConformance` component type.
+ /// The generated IR will include an extern symbol representing the type conformance
+ /// (typically a `IRWitnessTable` or a `specialize(IRWitnessTable)` inst), with a `public`
+ /// decoration to keep the referenced witness table alive during linking.
+ RefPtr<IRModule> generateIRForTypeConformance(
+ TypeConformance* typeConformance,
+ Int conformanceIdOverride,
+ DiagnosticSink* sink);
+ }
#endif
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 9b080a11c..6f9f127f8 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2735,6 +2735,11 @@ struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor
SLANG_UNUSED(specializationInfo);
}
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
SLANG_UNUSED(module);
@@ -2920,6 +2925,10 @@ struct CollectParametersVisitor : ComponentTypeVisitor
}
}
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
/// Recursively collect the global shader parameters and entry points in `program`.
@@ -3146,6 +3155,11 @@ struct CompleteBindingsVisitor : ComponentTypeVisitor
auto base = specialized->getBaseComponentType();
_completeBindings(m_context, base, m_counters);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
/// A visitor used by `_completeBindings`.
@@ -3272,6 +3286,10 @@ struct FlushPendingDataVisitor : ComponentTypeVisitor
m_counters->entryPointCounter += specialized->getEntryPointCount();
}
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
static void _completeBindings(
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index ef6872e76..e62d2c9ae 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1043,6 +1043,33 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent
return SLANG_OK;
}
+SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ slang::ITypeConformance** outConformanceComponentType,
+ SlangInt conformanceIdOverride,
+ ISlangBlob** outDiagnostics)
+{
+ RefPtr<TypeConformance> result;
+ DiagnosticSink sink;
+ try
+ {
+ SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink);
+ SemanticsVisitor visitor(&sharedSemanticsContext);
+ auto witness =
+ visitor.tryGetSubtypeWitness((Slang::Type*)type, (Slang::Type*)interfaceType);
+ if (auto subtypeWitness = as<SubtypeWitness>(witness))
+ {
+ result = new TypeConformance(this, subtypeWitness, conformanceIdOverride, &sink);
+ }
+ }
+ catch (...)
+ {}
+ sink.getBlobIfNeeded(outDiagnostics);
+ *outConformanceComponentType = result.detach();
+ return result ? SLANG_OK : SLANG_FAIL;
+}
+
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest(
SlangCompileRequest** outCompileRequest)
{
@@ -3041,6 +3068,11 @@ struct EnumerateModulesVisitor : ComponentTypeVisitor
{
visitChildren(specialized);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
@@ -3079,6 +3111,11 @@ struct EnumerateIRModulesVisitor : ComponentTypeVisitor
m_callback(specialized->getIRModule(), m_userData);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ m_callback(conformance->getIRModule(), m_userData);
+ }
};
void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData)
@@ -3401,6 +3438,11 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
{
visitChildren(specialized);
}
+
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
};
SpecializedComponentType::SpecializedComponentType(
@@ -3616,7 +3658,10 @@ SpecializedComponentType::SpecializedComponentType(
{ visitChildren(composite, specializationInfo); }
void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
{ visitChildren(specialized); }
-
+ void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(conformance);
+ }
EntryPointMangledNameCollector(ASTBuilder* astBuilder):
m_astBuilder(astBuilder)
{
diff --git a/tests/compute/dynamic-dispatch-16.slang b/tests/compute/dynamic-dispatch-16.slang
index 5ceb75cd7..b9fecb966 100644
--- a/tests/compute/dynamic-dispatch-16.slang
+++ b/tests/compute/dynamic-dispatch-16.slang
@@ -19,9 +19,12 @@ struct UserDefinedPackedType
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer
RWStructuredBuffer<float> gOutputBuffer;
-//TEST_INPUT: set gObj = new StructuredBuffer<UserDefinedPackedType>[new UserDefinedPackedType{[1.0, 0.0, 0.0], 0}, new UserDefinedPackedType{[2.0, 3.0, 4.0], 1}];
+//TEST_INPUT: set gObj = new StructuredBuffer<UserDefinedPackedType>[new UserDefinedPackedType{[1.0, 2.0, 3.0], 3}, new UserDefinedPackedType{[2.0, 3.0, 4.0], 4}];
RWStructuredBuffer<UserDefinedPackedType> gObj;
+//TEST_INPUT: type_conformance FloatVal:IInterface = 3
+//TEST_INPUT: type_conformance Float4Val:IInterface = 4
+
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
@@ -35,8 +38,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
gOutputBuffer[0] = result;
}
-// Type must be marked `public` to ensure it is visible in the generated DLL.
-public struct FloatVal : IInterface
+struct FloatVal : IInterface
{
float val;
float run()
@@ -46,7 +48,7 @@ public struct FloatVal : IInterface
};
interface ISomething{void g();}
struct Float4Struct : ISomething { float4 val; void g() {} }
-public struct Float4Val : IInterface
+struct Float4Val : IInterface
{
Float4Struct val;
float run()
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index 2e0741cea..7e6d290c8 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -837,6 +837,17 @@ namespace renderer_test
parentForNewVal->addField(field);
}
+ void parseTypeConformance(Misc::TokenReader& parser)
+ {
+ ShaderInputLayout::TypeConformanceVal conformance;
+ conformance.derivedTypeName = parseTypeName(parser);
+ parser.Read(":");
+ conformance.baseTypeName = parseTypeName(parser);
+ if (parser.AdvanceIf("="))
+ conformance.idOverride = parser.ReadInt();
+ layout->typeConformances.add(conformance);
+ }
+
void parseLine(Misc::TokenReader& parser)
{
if (parser.LookAhead("entryPointSpecializationArg")
@@ -872,6 +883,10 @@ namespace renderer_test
{
parseSetEntry(parser);
}
+ else if (parser.AdvanceIf("type_conformance"))
+ {
+ parseTypeConformance(parser);
+ }
else
{
parseValEntry(parser);
diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h
index 76ebe79b3..fe835a7f1 100644
--- a/tools/render-test/shader-input-layout.h
+++ b/tools/render-test/shader-input-layout.h
@@ -285,6 +285,16 @@ public:
Slang::RefPtr<AggVal> rootVal;
Slang::List<Slang::String> globalSpecializationArgs;
Slang::List<Slang::String> entryPointSpecializationArgs;
+
+ class TypeConformanceVal
+ {
+ public:
+ Slang::String derivedTypeName;
+ Slang::String baseTypeName;
+ Int idOverride = -1;
+ };
+ Slang::List<TypeConformanceVal> typeConformances;
+
int numRenderTargets = 1;
Slang::Index findEntryIndexByName(const Slang::String& name) const;
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index 65f8f244b..f479218f7 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -235,6 +235,37 @@ void ShaderCompilerUtil::Output::reset()
actualEntryPoints = request.entryPoints;
}
+ if (request.typeConformances.getCount())
+ {
+ ComPtr<slang::ISession> session;
+ slangRequest->getSession(session.writeRef());
+ List<ComPtr<slang::ITypeConformance>> typeConformanceComponents;
+ List<slang::IComponentType*> componentsRawPtr;
+ componentsRawPtr.add(linkedSlangProgram.get());
+ auto reflection = slang::ProgramLayout::get(slangRequest);
+ ComPtr<ISlangBlob> outDiagnostic;
+ for (auto& conformance : request.typeConformances)
+ {
+ auto derivedType = reflection->findTypeByName(conformance.derivedTypeName.getBuffer());
+ auto baseType = reflection->findTypeByName(conformance.baseTypeName.getBuffer());
+ ComPtr<slang::ITypeConformance> conformanceComponentType;
+ session->createTypeConformanceComponentType(
+ derivedType,
+ baseType,
+ conformanceComponentType.writeRef(),
+ conformance.idOverride,
+ outDiagnostic.writeRef());
+ typeConformanceComponents.add(conformanceComponentType);
+ componentsRawPtr.add(conformanceComponentType);
+ }
+ ComPtr<slang::IComponentType> newProgram;
+ session->createCompositeComponentType(
+ componentsRawPtr.getBuffer(),
+ componentsRawPtr.getCount(),
+ newProgram.writeRef(),
+ outDiagnostic.writeRef());
+ linkedSlangProgram = newProgram;
+ }
out.set(input.pipelineType, linkedSlangProgram);
return SLANG_OK;
}
@@ -415,6 +446,14 @@ void ShaderCompilerUtil::Output::reset()
}
compileRequest.globalSpecializationArgs = layout.globalSpecializationArgs;
compileRequest.entryPointSpecializationArgs = layout.entryPointSpecializationArgs;
+ for (auto conformance : layout.typeConformances)
+ {
+ ShaderCompileRequest::TypeConformance c;
+ c.derivedTypeName = conformance.derivedTypeName;
+ c.baseTypeName = conformance.baseTypeName;
+ c.idOverride = conformance.idOverride;
+ compileRequest.typeConformances.add(c);
+ }
return ShaderCompilerUtil::compileProgram(session, options, input, compileRequest, output.output);
}
diff --git a/tools/render-test/slang-support.h b/tools/render-test/slang-support.h
index da1f379fd..7770bada4 100644
--- a/tools/render-test/slang-support.h
+++ b/tools/render-test/slang-support.h
@@ -33,11 +33,20 @@ struct ShaderCompileRequest
SlangStage slangStage;
};
+ struct TypeConformance
+ {
+ public:
+ Slang::String derivedTypeName;
+ Slang::String baseTypeName;
+ Int idOverride;
+ };
+
SourceInfo source;
Slang::List<EntryPoint> entryPoints;
Slang::List<Slang::String> globalSpecializationArgs;
Slang::List<Slang::String> entryPointSpecializationArgs;
+ Slang::List<TypeConformance> typeConformances;
};