summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-01-31 16:26:03 -0800
committerGitHub <noreply@github.com>2022-01-31 16:26:03 -0800
commite59516fa8c3a16eb7b99a928c5b85b97bf44fd72 (patch)
tree869c2b8df0cc0d368af928324d53079a9f7999e0 /source
parent2bb43bbe4709533e0c6e53df1c62d368132dcd73 (diff)
Revise entrypoint renaming interface. (#2113)
Changed the interface from `IEntryPoint::getRenamedEntryPoint` to `IComponentType::renameEntryPoint`. The underlying implementation creates a `RenamedEntryPointComponentType` wrapper object around the base entry-point. This new implementation allows the user to specify entry point renaming on an IComponentType that isn't just a `EntryPoint`, but also on `SpecializedComponentType` or `CompositeComponentType` as long as the component defines a single entry point. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-compiler.cpp2
-rwxr-xr-xsource/slang/slang-compiler.h160
-rw-r--r--source/slang/slang-lower-to-ir.cpp7
-rw-r--r--source/slang/slang-parameter-binding.cpp28
-rw-r--r--source/slang/slang.cpp56
5 files changed, 239 insertions, 14 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index d90e9b102..0e2f339a9 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -287,7 +287,7 @@ namespace Slang
SLANG_UNUSED(index);
SLANG_ASSERT(index == 0);
- return m_nameOverride;
+ return m_name ? m_name->text : "";
}
void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 930210a96..9ae414424 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -293,6 +293,9 @@ namespace Slang
SlangInt specializationArgCount,
slang::IComponentType** outSpecializedComponentType,
ISlangBlob** outDiagnostics) SLANG_OVERRIDE;
+ SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName,
+ slang::IComponentType** outEntryPoint) SLANG_OVERRIDE;
SLANG_NO_THROW SlangResult SLANG_MCALL link(
slang::IComponentType** outLinkedComponentType,
ISlangBlob** outDiagnostics) SLANG_OVERRIDE;
@@ -654,6 +657,130 @@ namespace Slang
List<RefPtr<ComponentType>> m_requirements;
};
+ class RenamedEntryPointComponentType : public ComponentType
+ {
+ public:
+ using Super = ComponentType;
+
+ RenamedEntryPointComponentType(ComponentType* base, String newName);
+
+ ComponentType* getBase() { return m_base.Ptr(); }
+
+ // Forward `IComponentType` methods
+
+ SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE
+ {
+ return Super::getSession();
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL
+ getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getLayout(targetIndex, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::specialize(
+ specializationArgs,
+ specializationArgCount,
+ outSpecializedComponentType,
+ outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE
+ {
+ return Super::renameEntryPoint(newName, outEntryPoint);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::link(outLinkedComponentType, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointHostCallable(
+ entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ }
+
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE
+ {
+ return m_base->getModuleDependencies();
+ }
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE
+ {
+ return m_base->getFilePathDependencies();
+ }
+
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE
+ {
+ return m_base->getSpecializationParamCount();
+ }
+
+ SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE
+ {
+ return m_base->getSpecializationParam(index);
+ }
+
+ Index getRequirementCount() SLANG_OVERRIDE { return m_base->getRequirementCount(); }
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE
+ {
+ return m_base->getRequirement(index);
+ }
+ Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); }
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE
+ {
+ return m_base->getEntryPoint(index);
+ }
+ String getEntryPointMangledName(Index index) SLANG_OVERRIDE { return m_base->getEntryPointMangledName(index); }
+ String getEntryPointNameOverride(Index index) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(index);
+ SLANG_ASSERT(index == 0);
+ return m_entryPointNameOverride;
+ }
+
+ Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); }
+ ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE
+ {
+ return m_base->getShaderParam(index);
+ }
+
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+ SLANG_OVERRIDE;
+ private:
+ RefPtr<ComponentType> m_base;
+ String m_entryPointNameOverride;
+
+ protected:
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args, Index argCount, DiagnosticSink* sink) SLANG_OVERRIDE
+ {
+ return m_base->_validateSpecializationArgsImpl(args, argCount, sink);
+ }
+ };
+
/// Describes an entry point for the purposes of layout and code generation.
///
/// This class also tracks any generic arguments to the entry point,
@@ -710,6 +837,12 @@ namespace Slang
outDiagnostics);
}
+ SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE
+ {
+ return Super::renameEntryPoint(newName, outEntryPoint);
+ }
+
SLANG_NO_THROW SlangResult SLANG_MCALL link(
slang::IComponentType** outLinkedComponentType,
ISlangBlob** outDiagnostics) SLANG_OVERRIDE
@@ -728,15 +861,6 @@ namespace Slang
return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
}
- SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint)
- SLANG_OVERRIDE
- {
- RefPtr<EntryPoint> newEntryPoint = create(getLinkage(), m_funcDeclRef, m_profile);
- newEntryPoint->m_nameOverride = newName;
- *outEntryPoint = newEntryPoint.detach();
- return SLANG_OK;
- }
-
/// Create an entry point that refers to the given function.
static RefPtr<EntryPoint> create(
Linkage* linkage,
@@ -848,9 +972,6 @@ namespace Slang
/// The mangled name of the entry point function
String m_mangledName;
- /// The name of this entry point in the compiled code.
- String m_nameOverride;
-
SpecializationParams m_genericSpecializationParams;
SpecializationParams m_existentialSpecializationParams;
@@ -922,6 +1043,12 @@ namespace Slang
outDiagnostics);
}
+ SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE
+ {
+ return Super::renameEntryPoint(newName, outEntryPoint);
+ }
+
SLANG_NO_THROW SlangResult SLANG_MCALL link(
slang::IComponentType** outLinkedComponentType,
ISlangBlob** outDiagnostics) SLANG_OVERRIDE
@@ -1060,6 +1187,12 @@ namespace Slang
outDiagnostics);
}
+ SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE
+ {
+ return Super::renameEntryPoint(newName, outEntryPoint);
+ }
+
SLANG_NO_THROW SlangResult SLANG_MCALL link(
slang::IComponentType** outLinkedComponentType,
ISlangBlob** outDiagnostics) SLANG_OVERRIDE
@@ -1919,6 +2052,9 @@ namespace Slang
virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0;
virtual void visitSpecialized(SpecializedComponentType* specialized) = 0;
virtual void visitTypeConformance(TypeConformance* conformance) = 0;
+ virtual void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* renamedEntryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0;
protected:
// These helpers can be used to recurse into the logical children of a
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b47448ae1..d4b069dca 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8242,6 +8242,13 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor
lowerProgramEntryPointToIR(context, entryPoint, specializationInfo);
}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
// We've hit a leaf module, so we should be able to bind any global
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 6f9f127f8..48e637047 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2729,6 +2729,13 @@ struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor
ParameterBindingContext* m_context;
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
SLANG_UNUSED(entryPoint);
@@ -2889,6 +2896,13 @@ struct CollectParametersVisitor : ComponentTypeVisitor
collectEntryPointParameters(context, entryPoint, specializationInfo);
}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* renamedEntryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ renamedEntryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
// A single module represents a leaf case for layout.
@@ -3107,6 +3121,13 @@ struct CompleteBindingsVisitor : ComponentTypeVisitor
completeBindingsForParameter(m_context, globalEntryPointInfo->parametersLayout);
}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* renamedEntryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ renamedEntryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
SLANG_UNUSED(specializationInfo);
@@ -3236,6 +3257,13 @@ struct FlushPendingDataVisitor : ComponentTypeVisitor
_allocateBindingsForPendingData(m_context, globalEntryPointInfo->parametersLayout->pendingVarLayout);
}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
SLANG_UNUSED(specializationInfo);
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 1333e3660..3880d1114 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -3160,6 +3160,15 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
return SLANG_OK;
}
+SLANG_NO_THROW SlangResult SLANG_MCALL
+ ComponentType::renameEntryPoint(const char* newName, IComponentType** outEntryPoint)
+{
+ RefPtr<RenamedEntryPointComponentType> result =
+ new RenamedEntryPointComponentType(this, newName);
+ *outEntryPoint = result.detach();
+ return SLANG_OK;
+}
+
RefPtr<ComponentType> fillRequirements(
ComponentType* inComponentType);
@@ -3195,6 +3204,13 @@ struct EnumerateModulesVisitor : ComponentTypeVisitor
void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
{
m_callback(module, m_userData);
@@ -3236,6 +3252,13 @@ struct EnumerateIRModulesVisitor : ComponentTypeVisitor
void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
{
m_callback(module->getIRModule(), m_userData);
@@ -3427,7 +3450,6 @@ void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, Specia
visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo));
}
-
RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
@@ -3562,6 +3584,13 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
collectReferencedModules(specializationInfo->existentialSpecializationArgs);
}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ }
+
void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
SLANG_UNUSED(module);
@@ -3801,6 +3830,14 @@ SpecializedComponentType::SpecializedComponentType(
(*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0));
}
+ void visitRenamedEntryPoint(
+ RenamedEntryPointComponentType* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ entryPoint->getBase()->acceptVisitor(this, specializationInfo);
+ (*entryPointNameOverrides).getLast() = entryPoint->getEntryPointNameOverride(0);
+ }
+
void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
{}
void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
@@ -3854,6 +3891,23 @@ String SpecializedComponentType::getEntryPointNameOverride(Index index)
return m_entryPointNameOverrides[index];
}
+// RenamedEntryPointComponentType
+
+RenamedEntryPointComponentType::RenamedEntryPointComponentType(
+ ComponentType* base, String newName)
+ : ComponentType(base->getLinkage())
+ , m_base(base)
+ , m_entryPointNameOverride(newName)
+{
+}
+
+void RenamedEntryPointComponentType::acceptVisitor(
+ ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+{
+ visitor->visitRenamedEntryPoint(
+ this, as<EntryPoint::EntryPointSpecializationInfo>(specializationInfo));
+}
+
void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo)
{
auto childCount = composite->getChildComponentCount();