diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-compiler.cpp | 2 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 160 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 56 |
5 files changed, 239 insertions, 14 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index d90e9b102..0e2f339a9 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -287,7 +287,7 @@ namespace Slang SLANG_UNUSED(index); SLANG_ASSERT(index == 0); - return m_nameOverride; + return m_name ? m_name->text : ""; } void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 930210a96..9ae414424 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -293,6 +293,9 @@ namespace Slang SlangInt specializationArgCount, slang::IComponentType** outSpecializedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, + slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; SLANG_NO_THROW SlangResult SLANG_MCALL link( slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE; @@ -654,6 +657,130 @@ namespace Slang List<RefPtr<ComponentType>> m_requirements; }; + class RenamedEntryPointComponentType : public ComponentType + { + public: + using Super = ComponentType; + + RenamedEntryPointComponentType(ComponentType* base, String newName); + + ComponentType* getBase() { return m_base.Ptr(); } + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + } + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE + { + return m_base->getModuleDependencies(); + } + List<String> const& getFilePathDependencies() SLANG_OVERRIDE + { + return m_base->getFilePathDependencies(); + } + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE + { + return m_base->getSpecializationParamCount(); + } + + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + return m_base->getSpecializationParam(index); + } + + Index getRequirementCount() SLANG_OVERRIDE { return m_base->getRequirementCount(); } + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE + { + return m_base->getRequirement(index); + } + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + return m_base->getEntryPoint(index); + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE { return m_base->getEntryPointMangledName(index); } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + return m_entryPointNameOverride; + } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + return m_base->getShaderParam(index); + } + + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + private: + RefPtr<ComponentType> m_base; + String m_entryPointNameOverride; + + protected: + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, Index argCount, DiagnosticSink* sink) SLANG_OVERRIDE + { + return m_base->_validateSpecializationArgsImpl(args, argCount, sink); + } + }; + /// Describes an entry point for the purposes of layout and code generation. /// /// This class also tracks any generic arguments to the entry point, @@ -710,6 +837,12 @@ namespace Slang outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + SLANG_NO_THROW SlangResult SLANG_MCALL link( slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE @@ -728,15 +861,6 @@ namespace Slang return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); } - SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint) - SLANG_OVERRIDE - { - RefPtr<EntryPoint> newEntryPoint = create(getLinkage(), m_funcDeclRef, m_profile); - newEntryPoint->m_nameOverride = newName; - *outEntryPoint = newEntryPoint.detach(); - return SLANG_OK; - } - /// Create an entry point that refers to the given function. static RefPtr<EntryPoint> create( Linkage* linkage, @@ -848,9 +972,6 @@ namespace Slang /// The mangled name of the entry point function String m_mangledName; - /// The name of this entry point in the compiled code. - String m_nameOverride; - SpecializationParams m_genericSpecializationParams; SpecializationParams m_existentialSpecializationParams; @@ -922,6 +1043,12 @@ namespace Slang outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + SLANG_NO_THROW SlangResult SLANG_MCALL link( slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE @@ -1060,6 +1187,12 @@ namespace Slang outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + SLANG_NO_THROW SlangResult SLANG_MCALL link( slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE @@ -1919,6 +2052,9 @@ namespace Slang virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; virtual void visitTypeConformance(TypeConformance* conformance) = 0; + virtual void visitRenamedEntryPoint( + RenamedEntryPointComponentType* renamedEntryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; protected: // These helpers can be used to recurse into the logical children of a diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b47448ae1..d4b069dca 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8242,6 +8242,13 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor lowerProgramEntryPointToIR(context, entryPoint, specializationInfo); } + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { // We've hit a leaf module, so we should be able to bind any global diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 6f9f127f8..48e637047 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2729,6 +2729,13 @@ struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor ParameterBindingContext* m_context; + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE { SLANG_UNUSED(entryPoint); @@ -2889,6 +2896,13 @@ struct CollectParametersVisitor : ComponentTypeVisitor collectEntryPointParameters(context, entryPoint, specializationInfo); } + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* renamedEntryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + renamedEntryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { // A single module represents a leaf case for layout. @@ -3107,6 +3121,13 @@ struct CompleteBindingsVisitor : ComponentTypeVisitor completeBindingsForParameter(m_context, globalEntryPointInfo->parametersLayout); } + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* renamedEntryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + renamedEntryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { SLANG_UNUSED(specializationInfo); @@ -3236,6 +3257,13 @@ struct FlushPendingDataVisitor : ComponentTypeVisitor _allocateBindingsForPendingData(m_context, globalEntryPointInfo->parametersLayout->pendingVarLayout); } + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { SLANG_UNUSED(specializationInfo); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 1333e3660..3880d1114 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3160,6 +3160,15 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( return SLANG_OK; } +SLANG_NO_THROW SlangResult SLANG_MCALL + ComponentType::renameEntryPoint(const char* newName, IComponentType** outEntryPoint) +{ + RefPtr<RenamedEntryPointComponentType> result = + new RenamedEntryPointComponentType(this, newName); + *outEntryPoint = result.detach(); + return SLANG_OK; +} + RefPtr<ComponentType> fillRequirements( ComponentType* inComponentType); @@ -3195,6 +3204,13 @@ struct EnumerateModulesVisitor : ComponentTypeVisitor void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE { m_callback(module, m_userData); @@ -3236,6 +3252,13 @@ struct EnumerateIRModulesVisitor : ComponentTypeVisitor void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE { m_callback(module->getIRModule(), m_userData); @@ -3427,7 +3450,6 @@ void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, Specia visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); } - RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, @@ -3562,6 +3584,13 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor collectReferencedModules(specializationInfo->existentialSpecializationArgs); } + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { SLANG_UNUSED(module); @@ -3801,6 +3830,14 @@ SpecializedComponentType::SpecializedComponentType( (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0)); } + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + (*entryPointNameOverrides).getLast() = entryPoint->getEntryPointNameOverride(0); + } + void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE {} void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE @@ -3854,6 +3891,23 @@ String SpecializedComponentType::getEntryPointNameOverride(Index index) return m_entryPointNameOverrides[index]; } +// RenamedEntryPointComponentType + +RenamedEntryPointComponentType::RenamedEntryPointComponentType( + ComponentType* base, String newName) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_entryPointNameOverride(newName) +{ +} + +void RenamedEntryPointComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + visitor->visitRenamedEntryPoint( + this, as<EntryPoint::EntryPointSpecializationInfo>(specializationInfo)); +} + void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) { auto childCount = composite->getChildComponentCount(); |
