summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-01-21 12:13:23 -0800
committerGitHub <noreply@github.com>2022-01-21 12:13:23 -0800
commit7cff340b10b27f82781335093759bbdc19cd2865 (patch)
tree15a983c4fdfb68b700ab532d5d2a684e8727e6b4
parentf85bc7ae98486b37518958e659f659f1ff9b125c (diff)
Add entry-point name override feature. (#2089)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--slang.h2
-rw-r--r--source/slang/slang-compiler.cpp14
-rwxr-xr-xsource/slang/slang-compiler.h22
-rw-r--r--source/slang/slang-ir-link.cpp20
-rw-r--r--source/slang/slang.cpp14
5 files changed, 69 insertions, 3 deletions
diff --git a/slang.h b/slang.h
index d47617edf..86a752f01 100644
--- a/slang.h
+++ b/slang.h
@@ -4273,6 +4273,8 @@ namespace slang
struct IEntryPoint : public IComponentType
{
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint) = 0;
+
SLANG_COM_INTERFACE(0x8f241361, 0xf5bd, 0x4ca0, { 0xa3, 0xac, 0x2, 0xf7, 0xfa, 0x24, 0x2, 0xb8 })
};
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index c1a768f02..d90e9b102 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -282,6 +282,14 @@ namespace Slang
return m_mangledName;
}
+ String EntryPoint::getEntryPointNameOverride(Index index)
+ {
+ SLANG_UNUSED(index);
+ SLANG_ASSERT(index == 0);
+
+ return m_nameOverride;
+ }
+
void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
{
visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo));
@@ -1211,7 +1219,11 @@ namespace Slang
// Set the entry point name
options.entryPointName = getText(entryPoint->getName());
-
+ auto entryPointNameOverride = program->getEntryPointNameOverride(entryPointIndex);
+ if (entryPointNameOverride.getLength() != 0)
+ {
+ options.entryPointName = entryPointNameOverride;
+ }
if (compilerType == PassThroughMode::Dxc)
{
// We will enable the flag to generate proper code for 16 - bit types
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 52c03ffdb..930210a96 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -319,6 +319,9 @@ namespace Slang
/// Get the mangled name of one of the entry points linked into this component type.
virtual String getEntryPointMangledName(Index index) = 0;
+ /// Get the name override of one of the entry points linked into this component type.
+ virtual String getEntryPointNameOverride(Index index) = 0;
+
/// Get the number of global shader parameters linked into this component type.
virtual Index getShaderParamCount() = 0;
@@ -511,6 +514,7 @@ namespace Slang
Index getEntryPointCount() SLANG_OVERRIDE;
RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE;
String getEntryPointMangledName(Index index) SLANG_OVERRIDE;
+ String getEntryPointNameOverride(Index index) SLANG_OVERRIDE;
Index getShaderParamCount() SLANG_OVERRIDE;
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE;
@@ -560,6 +564,7 @@ namespace Slang
//
List<EntryPoint*> m_entryPoints;
List<String> m_entryPointMangledNames;
+ List<String> m_entryPointNameOverrides;
List<ShaderParamInfo> m_shaderParams;
List<SpecializationParam> m_specializationParams;
List<ComponentType*> m_requirements;
@@ -598,6 +603,7 @@ namespace Slang
Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); }
RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); }
String getEntryPointMangledName(Index index) SLANG_OVERRIDE;
+ String getEntryPointNameOverride(Index index) SLANG_OVERRIDE;
Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); }
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); }
@@ -638,6 +644,7 @@ namespace Slang
RefPtr<IRModule> m_irModule;
List<String> m_entryPointMangledNames;
+ List<String> m_entryPointNameOverrides;
// Any tagged union types that were referenced by the specialization arguments.
List<TaggedUnionType*> m_taggedUnionTypes;
@@ -721,6 +728,15 @@ namespace Slang
return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
}
+ SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint)
+ SLANG_OVERRIDE
+ {
+ RefPtr<EntryPoint> newEntryPoint = create(getLinkage(), m_funcDeclRef, m_profile);
+ newEntryPoint->m_nameOverride = newName;
+ *outEntryPoint = newEntryPoint.detach();
+ return SLANG_OK;
+ }
+
/// Create an entry point that refers to the given function.
static RefPtr<EntryPoint> create(
Linkage* linkage,
@@ -791,6 +807,7 @@ namespace Slang
Index getEntryPointCount() SLANG_OVERRIDE { return 1; };
RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; }
String getEntryPointMangledName(Index index) SLANG_OVERRIDE;
+ String getEntryPointNameOverride(Index index) SLANG_OVERRIDE;
Index getShaderParamCount() SLANG_OVERRIDE { return 0; }
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return ShaderParamInfo(); }
@@ -831,6 +848,9 @@ namespace Slang
/// The mangled name of the entry point function
String m_mangledName;
+ /// The name of this entry point in the compiled code.
+ String m_nameOverride;
+
SpecializationParams m_genericSpecializationParams;
SpecializationParams m_existentialSpecializationParams;
@@ -940,6 +960,7 @@ namespace Slang
return nullptr;
}
String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; }
+ String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; }
Index getShaderParamCount() SLANG_OVERRIDE { return 0; }
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE
@@ -1107,6 +1128,7 @@ namespace Slang
Index getEntryPointCount() SLANG_OVERRIDE { return 0; }
RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; }
String getEntryPointMangledName(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); }
+ String getEntryPointNameOverride(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); }
Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); }
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; }
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 77998fba1..385b536dc 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -809,7 +809,8 @@ static void maybeCopyLayoutInformationToParameters(
IRFunc* specializeIRForEntryPoint(
IRSpecContext* context,
- String const& mangledName)
+ String const& mangledName,
+ String const& nameOverride)
{
// We start by looking up the IR symbol that
// matches the mangled name given to the
@@ -848,6 +849,20 @@ IRFunc* specializeIRForEntryPoint(
//
auto clonedVal = cloneGlobalValue(context, originalVal);
+ if (nameOverride.getLength())
+ {
+ if (auto entryPointDecor = clonedVal->findDecoration<IREntryPointDecoration>())
+ {
+ IRInst* operands[] = {
+ entryPointDecor->getProfileInst(),
+ context->builder->getStringValue(nameOverride.getUnownedSlice()),
+ entryPointDecor->getModuleName()};
+ context->builder->addDecoration(
+ clonedVal, IROp::kIROp_EntryPointDecoration, operands, 3);
+ entryPointDecor->removeAndDeallocate();
+ }
+ }
+
// In the case where the user is requesting a specialization
// of a generic entry point, we have a bit of a problem.
//
@@ -1425,7 +1440,8 @@ LinkedIR linkIR(
for (auto entryPointIndex : entryPointIndices)
{
auto entryPointMangledName = program->getEntryPointMangledName(entryPointIndex);
- irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName));
+ auto nameOverride = program->getEntryPointNameOverride(entryPointIndex);
+ irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName, nameOverride));
}
// Layout information for global shader parameters is also required,
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 5fbc5b4bf..1333e3660 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -3325,6 +3325,7 @@ CompositeComponentType::CompositeComponentType(
{
m_entryPoints.add(child->getEntryPoint(cc));
m_entryPointMangledNames.add(child->getEntryPointMangledName(cc));
+ m_entryPointNameOverrides.add(child->getEntryPointNameOverride(cc));
}
auto childShaderParamCount = child->getShaderParamCount();
@@ -3376,6 +3377,11 @@ String CompositeComponentType::getEntryPointMangledName(Index index)
return m_entryPointMangledNames[index];
}
+String CompositeComponentType::getEntryPointNameOverride(Index index)
+{
+ return m_entryPointNameOverrides[index];
+}
+
Index CompositeComponentType::getShaderParamCount()
{
return m_shaderParams.getCount();
@@ -3783,6 +3789,7 @@ SpecializedComponentType::SpecializedComponentType(
struct EntryPointMangledNameCollector : ComponentTypeVisitor
{
List<String>* mangledEntryPointNames;
+ List<String>* entryPointNameOverrides;
void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
@@ -3791,6 +3798,7 @@ SpecializedComponentType::SpecializedComponentType(
funcDeclRef = specializationInfo->specializedFuncDeclRef;
(*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef));
+ (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0));
}
void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
@@ -3815,6 +3823,7 @@ SpecializedComponentType::SpecializedComponentType(
//
EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder());
collector.mangledEntryPointNames = &m_entryPointMangledNames;
+ collector.entryPointNameOverrides = &m_entryPointNameOverrides;
collector.visitSpecialized(this);
}
@@ -3840,6 +3849,11 @@ String SpecializedComponentType::getEntryPointMangledName(Index index)
return m_entryPointMangledNames[index];
}
+String SpecializedComponentType::getEntryPointNameOverride(Index index)
+{
+ return m_entryPointNameOverrides[index];
+}
+
void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo)
{
auto childCount = composite->getChildComponentCount();