diff options
| -rw-r--r-- | slang.h | 33 | ||||
| -rw-r--r-- | source/slang/slang-compiler.cpp | 86 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 126 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 64 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 47 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-16.slang | 10 | ||||
| -rw-r--r-- | tools/render-test/shader-input-layout.cpp | 15 | ||||
| -rw-r--r-- | tools/render-test/shader-input-layout.h | 10 | ||||
| -rw-r--r-- | tools/render-test/slang-support.cpp | 39 | ||||
| -rw-r--r-- | tools/render-test/slang-support.h | 9 |
14 files changed, 470 insertions, 7 deletions
@@ -3050,6 +3050,7 @@ namespace slang typedef ISlangBlob IBlob; struct IComponentType; + struct ITypeConformance; struct IGlobalSession; struct IModule; struct ISession; @@ -4023,6 +4024,32 @@ namespace slang */ virtual SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest( SlangCompileRequest** outCompileRequest) = 0; + + + /** Creates a `IComponentType` that represents a type's conformance to an interface. + The retrieved `ITypeConformance` objects can be included in a composite `IComponentType` + to explicitly specify which implementation types should be included in the final compiled + code. For example, if an module defines `IMaterial` interface and `AMaterial`, + `BMaterial`, `CMaterial` types that implements the interface, the user can exclude + `CMaterial` implementation from the resulting shader code by explcitly adding + `AMaterial:IMaterial` and `BMaterial:IMaterial` conformances to a composite + `IComponentType` and get entry point code from it. The resulting code will not have + anything related to `CMaterial` in the dynamic dispatch logic. If the user does not + explicitly include any `TypeConformances` to an interface type, all implementations to + that interface will be included by default. By linking a `ITypeConformance`, the user is + also given the opportunity to specify the dispatch ID of the implementation type. If + `conformanceIdOverride` is -1, there will be no override behavior and Slang will + automatically assign IDs to implementation types. The automatically assigned IDs can be + queried via `ISession::getTypeConformanceWitnessSequentialID`. + + Returns SLANG_OK if succeeds, or SLANG_FAIL if `type` does not conform to `interfaceType`. + */ + virtual SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + ITypeConformance** outConformance, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) = 0; }; #define SLANG_UUID_ISession ISession::getTypeGuid() @@ -4204,6 +4231,12 @@ namespace slang #define SLANG_UUID_IEntryPoint IEntryPoint::getTypeGuid() + struct ITypeConformance : public IComponentType + { + SLANG_COM_INTERFACE(0x73eb3147, 0xe544, 0x41b5, { 0xb8, 0xf0, 0xa2, 0x44, 0xdf, 0x21, 0x94, 0xb }) + }; + #define SLANG_UUID_ITypeConformance ITypeConformance::getTypeGuid() + /** A module is the granularity of shader code compilation and loading. In most cases a module corresponds to a single compile "translation unit." diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 997f3fd51..5d0ae2f8d 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -296,6 +296,92 @@ namespace Slang return empty; } + TypeConformance::TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_subtypeWitness(witness) + , m_conformanceIdOverride(confomrmanceIdOverride) + { + addDepedencyFromWitness(witness); + m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); + } + + void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) + { + if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) + { + auto declModule = getModule(declaredWitness->declRef.getDecl()); + m_moduleDependency.addDependency(declModule); + m_pathDependency.addDependency(declModule); + if (m_requirementSet.Add(declModule)) + { + m_requirements.add(declModule); + } + // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. + } + else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) + { + addDepedencyFromWitness(transitiveWitness->midToSup); + addDepedencyFromWitness(transitiveWitness->subToMid); + } + else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) + { + auto left = as<SubtypeWitness>(conjunctionWitness->leftWitness); + if (left) + addDepedencyFromWitness(left); + auto right = as<SubtypeWitness>(conjunctionWitness->rightWitness); + if (right) + addDepedencyFromWitness(right); + } + } + + ISlangUnknown* TypeConformance::getInterface(const Guid& guid) + { + if (guid == slang::ITypeConformance::getTypeGuid()) + return static_cast<slang::ITypeConformance*>(this); + + return Super::getInterface(guid); + } + + List<Module*> const& TypeConformance::getModuleDependencies() + { + return m_moduleDependency.getModuleList(); + } + + List<String> const& TypeConformance::getFilePathDependencies() + { + return m_pathDependency.getFilePathList(); + } + + Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); } + + RefPtr<ComponentType> TypeConformance::getRequirement(Index index) + { + return m_requirements[index]; + } + + void TypeConformance::acceptVisitor( + ComponentTypeVisitor* visitor, + ComponentType::SpecializationInfo* specializationInfo) + { + SLANG_UNUSED(specializationInfo); + visitor->visitTypeConformance(this); + } + + RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) + { + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; + } + // Profile Profile::lookUp(UnownedStringSlice const& name) diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 1724794cf..727d21c3a 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -850,6 +850,125 @@ namespace Slang Profile m_profile; }; + class TypeConformance + : public ComponentType + , public slang::ITypeConformance + { + typedef ComponentType Super; + + public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + ISlangUnknown* getInterface(const Guid& guid); + + TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink); + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + } + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; + List<String> const& getFilePathDependencies() SLANG_OVERRIDE; + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + + /// Get the existential type parameter at `index`. + SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE + { + static SpecializationParam emptyParam; + return emptyParam; + } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + Index getEntryPointCount() SLANG_OVERRIDE { return 0; }; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return nullptr; + } + String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } + + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return ShaderParamInfo(); + } + + SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; } + IRModule* getIRModule() { return m_irModule.Ptr(); } + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + private: + SubtypeWitness* m_subtypeWitness; + ModuleDependencyList m_moduleDependency; + FilePathDependencyList m_pathDependency; + List<RefPtr<Module>> m_requirements; + HashSet<Module*> m_requirementSet; + RefPtr<IRModule> m_irModule; + Int m_conformanceIdOverride; + void addDepedencyFromWitness(SubtypeWitness* witness); + }; + enum class PassThroughMode : SlangPassThroughIntegral { None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler @@ -1319,6 +1438,12 @@ namespace Slang slang::TypeReflection* type, slang::TypeReflection* interfaceType, uint32_t* outId) override; + SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformance, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) override; SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest( SlangCompileRequest** outCompileRequest) override; @@ -1756,6 +1881,7 @@ namespace Slang virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0; virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; + virtual void visitTypeConformance(TypeConformance* conformance) = 0; protected: // These helpers can be used to recurse into the logical children of a diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 480fe504c..36985efd6 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -274,7 +274,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) default: { - // In the deafult case, assume that we have some sort of "hoistable" + // In the default case, assume that we have some sort of "hoistable" // instruction that requires us to create a clone of it. UInt argCount = originalValue->getOperandCount(); @@ -439,6 +439,8 @@ static void cloneExtraDecorations( case kIROp_BindExistentialSlotsDecoration: case kIROp_LayoutDecoration: + case kIROp_PublicDecoration: + case kIROp_SequentialIDDecoration: if(!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index e9300ae06..9522dbeca 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -2158,6 +2158,11 @@ IRInst* specializeGenericImpl( if( auto returnValInst = as<IRReturnVal>(ii) ) { auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); + + // Clone decorations on the orignal `specialize` inst over to the newly specialized + // value. + cloneInstDecorationsAndChildren( + &env, &sharedBuilderStorage, specializeInst, specializedVal); return specializedVal; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index c7e32072e..8a40dbde9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8255,6 +8255,11 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor { visitChildren(specialized); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; RefPtr<IRModule> generateIRForSpecializedComponentType( @@ -8265,6 +8270,65 @@ RefPtr<IRModule> generateIRForSpecializedComponentType( return context.process(componentType, sink); } + /// Context for generating IR code to represent a `TypeConformance` +struct TypeConformanceIRGenContext +{ + DiagnosticSink* sink; + Linkage* linkage; + Session* session; + IRGenContext* context; + IRBuilder* builder; + + RefPtr<IRModule> process( + TypeConformance* typeConformance, + Int conformanceIdOverride, + DiagnosticSink* inSink) + { + sink = inSink; + + linkage = typeConformance->getLinkage(); + session = linkage->getSessionImpl(); + + SharedIRGenContext sharedContextStorage(session, sink, linkage->m_obfuscateCode); + SharedIRGenContext* sharedContext = &sharedContextStorage; + + IRGenContext contextStorage(sharedContext, linkage->getASTBuilder()); + context = &contextStorage; + + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder builderStorage; + builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + RefPtr<IRModule> module = builder->createModule(); + sharedBuilder->module = module; + + builder->setInsertInto(module->getModuleInst()); + + context->irBuilder = builder; + + auto witness = lowerSimpleVal(context, typeConformance->getSubtypeWitness()); + builder->addPublicDecoration(witness); + if (conformanceIdOverride != -1) + { + builder->addSequentialIDDecoration(witness, conformanceIdOverride); + } + return module; + } +}; + +RefPtr<IRModule> generateIRForTypeConformance( + TypeConformance* typeConformance, + Int conformanceIdOverride, + DiagnosticSink* sink) +{ + TypeConformanceIRGenContext context; + return context.process(typeConformance, conformanceIdOverride, sink); +} RefPtr<IRModule> TargetProgram::getOrCreateIRModuleForLayout(DiagnosticSink* sink) { diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h index dbf2e550a..ce7f9eaf0 100644 --- a/source/slang/slang-lower-to-ir.h +++ b/source/slang/slang-lower-to-ir.h @@ -43,5 +43,14 @@ namespace Slang RefPtr<IRModule> generateIRForSpecializedComponentType( SpecializedComponentType* componentType, DiagnosticSink* sink); -} + + /// Generate an IR module to represent a user specified `TypeConformance` component type. + /// The generated IR will include an extern symbol representing the type conformance + /// (typically a `IRWitnessTable` or a `specialize(IRWitnessTable)` inst), with a `public` + /// decoration to keep the referenced witness table alive during linking. + RefPtr<IRModule> generateIRForTypeConformance( + TypeConformance* typeConformance, + Int conformanceIdOverride, + DiagnosticSink* sink); + } #endif diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 9b080a11c..6f9f127f8 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2735,6 +2735,11 @@ struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor SLANG_UNUSED(specializationInfo); } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { SLANG_UNUSED(module); @@ -2920,6 +2925,10 @@ struct CollectParametersVisitor : ComponentTypeVisitor } } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; /// Recursively collect the global shader parameters and entry points in `program`. @@ -3146,6 +3155,11 @@ struct CompleteBindingsVisitor : ComponentTypeVisitor auto base = specialized->getBaseComponentType(); _completeBindings(m_context, base, m_counters); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; /// A visitor used by `_completeBindings`. @@ -3272,6 +3286,10 @@ struct FlushPendingDataVisitor : ComponentTypeVisitor m_counters->entryPointCounter += specialized->getEntryPointCount(); } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; static void _completeBindings( diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index ef6872e76..e62d2c9ae 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1043,6 +1043,33 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent return SLANG_OK; } +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformanceComponentType, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) +{ + RefPtr<TypeConformance> result; + DiagnosticSink sink; + try + { + SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + auto witness = + visitor.tryGetSubtypeWitness((Slang::Type*)type, (Slang::Type*)interfaceType); + if (auto subtypeWitness = as<SubtypeWitness>(witness)) + { + result = new TypeConformance(this, subtypeWitness, conformanceIdOverride, &sink); + } + } + catch (...) + {} + sink.getBlobIfNeeded(outDiagnostics); + *outConformanceComponentType = result.detach(); + return result ? SLANG_OK : SLANG_FAIL; +} + SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest( SlangCompileRequest** outCompileRequest) { @@ -3041,6 +3068,11 @@ struct EnumerateModulesVisitor : ComponentTypeVisitor { visitChildren(specialized); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; @@ -3079,6 +3111,11 @@ struct EnumerateIRModulesVisitor : ComponentTypeVisitor m_callback(specialized->getIRModule(), m_userData); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + m_callback(conformance->getIRModule(), m_userData); + } }; void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) @@ -3401,6 +3438,11 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor { visitChildren(specialized); } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } }; SpecializedComponentType::SpecializedComponentType( @@ -3616,7 +3658,10 @@ SpecializedComponentType::SpecializedComponentType( { visitChildren(composite, specializationInfo); } void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE { visitChildren(specialized); } - + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } EntryPointMangledNameCollector(ASTBuilder* astBuilder): m_astBuilder(astBuilder) { diff --git a/tests/compute/dynamic-dispatch-16.slang b/tests/compute/dynamic-dispatch-16.slang index 5ceb75cd7..b9fecb966 100644 --- a/tests/compute/dynamic-dispatch-16.slang +++ b/tests/compute/dynamic-dispatch-16.slang @@ -19,9 +19,12 @@ struct UserDefinedPackedType //TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer RWStructuredBuffer<float> gOutputBuffer; -//TEST_INPUT: set gObj = new StructuredBuffer<UserDefinedPackedType>[new UserDefinedPackedType{[1.0, 0.0, 0.0], 0}, new UserDefinedPackedType{[2.0, 3.0, 4.0], 1}]; +//TEST_INPUT: set gObj = new StructuredBuffer<UserDefinedPackedType>[new UserDefinedPackedType{[1.0, 2.0, 3.0], 3}, new UserDefinedPackedType{[2.0, 3.0, 4.0], 4}]; RWStructuredBuffer<UserDefinedPackedType> gObj; +//TEST_INPUT: type_conformance FloatVal:IInterface = 3 +//TEST_INPUT: type_conformance Float4Val:IInterface = 4 + [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { @@ -35,8 +38,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) gOutputBuffer[0] = result; } -// Type must be marked `public` to ensure it is visible in the generated DLL. -public struct FloatVal : IInterface +struct FloatVal : IInterface { float val; float run() @@ -46,7 +48,7 @@ public struct FloatVal : IInterface }; interface ISomething{void g();} struct Float4Struct : ISomething { float4 val; void g() {} } -public struct Float4Val : IInterface +struct Float4Val : IInterface { Float4Struct val; float run() diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 2e0741cea..7e6d290c8 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -837,6 +837,17 @@ namespace renderer_test parentForNewVal->addField(field); } + void parseTypeConformance(Misc::TokenReader& parser) + { + ShaderInputLayout::TypeConformanceVal conformance; + conformance.derivedTypeName = parseTypeName(parser); + parser.Read(":"); + conformance.baseTypeName = parseTypeName(parser); + if (parser.AdvanceIf("=")) + conformance.idOverride = parser.ReadInt(); + layout->typeConformances.add(conformance); + } + void parseLine(Misc::TokenReader& parser) { if (parser.LookAhead("entryPointSpecializationArg") @@ -872,6 +883,10 @@ namespace renderer_test { parseSetEntry(parser); } + else if (parser.AdvanceIf("type_conformance")) + { + parseTypeConformance(parser); + } else { parseValEntry(parser); diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index 76ebe79b3..fe835a7f1 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -285,6 +285,16 @@ public: Slang::RefPtr<AggVal> rootVal; Slang::List<Slang::String> globalSpecializationArgs; Slang::List<Slang::String> entryPointSpecializationArgs; + + class TypeConformanceVal + { + public: + Slang::String derivedTypeName; + Slang::String baseTypeName; + Int idOverride = -1; + }; + Slang::List<TypeConformanceVal> typeConformances; + int numRenderTargets = 1; Slang::Index findEntryIndexByName(const Slang::String& name) const; diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 65f8f244b..f479218f7 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -235,6 +235,37 @@ void ShaderCompilerUtil::Output::reset() actualEntryPoints = request.entryPoints; } + if (request.typeConformances.getCount()) + { + ComPtr<slang::ISession> session; + slangRequest->getSession(session.writeRef()); + List<ComPtr<slang::ITypeConformance>> typeConformanceComponents; + List<slang::IComponentType*> componentsRawPtr; + componentsRawPtr.add(linkedSlangProgram.get()); + auto reflection = slang::ProgramLayout::get(slangRequest); + ComPtr<ISlangBlob> outDiagnostic; + for (auto& conformance : request.typeConformances) + { + auto derivedType = reflection->findTypeByName(conformance.derivedTypeName.getBuffer()); + auto baseType = reflection->findTypeByName(conformance.baseTypeName.getBuffer()); + ComPtr<slang::ITypeConformance> conformanceComponentType; + session->createTypeConformanceComponentType( + derivedType, + baseType, + conformanceComponentType.writeRef(), + conformance.idOverride, + outDiagnostic.writeRef()); + typeConformanceComponents.add(conformanceComponentType); + componentsRawPtr.add(conformanceComponentType); + } + ComPtr<slang::IComponentType> newProgram; + session->createCompositeComponentType( + componentsRawPtr.getBuffer(), + componentsRawPtr.getCount(), + newProgram.writeRef(), + outDiagnostic.writeRef()); + linkedSlangProgram = newProgram; + } out.set(input.pipelineType, linkedSlangProgram); return SLANG_OK; } @@ -415,6 +446,14 @@ void ShaderCompilerUtil::Output::reset() } compileRequest.globalSpecializationArgs = layout.globalSpecializationArgs; compileRequest.entryPointSpecializationArgs = layout.entryPointSpecializationArgs; + for (auto conformance : layout.typeConformances) + { + ShaderCompileRequest::TypeConformance c; + c.derivedTypeName = conformance.derivedTypeName; + c.baseTypeName = conformance.baseTypeName; + c.idOverride = conformance.idOverride; + compileRequest.typeConformances.add(c); + } return ShaderCompilerUtil::compileProgram(session, options, input, compileRequest, output.output); } diff --git a/tools/render-test/slang-support.h b/tools/render-test/slang-support.h index da1f379fd..7770bada4 100644 --- a/tools/render-test/slang-support.h +++ b/tools/render-test/slang-support.h @@ -33,11 +33,20 @@ struct ShaderCompileRequest SlangStage slangStage; }; + struct TypeConformance + { + public: + Slang::String derivedTypeName; + Slang::String baseTypeName; + Int idOverride; + }; + SourceInfo source; Slang::List<EntryPoint> entryPoints; Slang::List<Slang::String> globalSpecializationArgs; Slang::List<Slang::String> entryPointSpecializationArgs; + Slang::List<TypeConformance> typeConformances; }; |
