diff options
| author | Yong He <yonghe@outlook.com> | 2022-01-21 12:13:23 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-01-21 12:13:23 -0800 |
| commit | 7cff340b10b27f82781335093759bbdc19cd2865 (patch) | |
| tree | 15a983c4fdfb68b700ab532d5d2a684e8727e6b4 | |
| parent | f85bc7ae98486b37518958e659f659f1ff9b125c (diff) | |
Add entry-point name override feature. (#2089)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | slang.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-compiler.cpp | 14 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 22 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 14 |
5 files changed, 69 insertions, 3 deletions
@@ -4273,6 +4273,8 @@ namespace slang struct IEntryPoint : public IComponentType { + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint) = 0; + SLANG_COM_INTERFACE(0x8f241361, 0xf5bd, 0x4ca0, { 0xa3, 0xac, 0x2, 0xf7, 0xfa, 0x24, 0x2, 0xb8 }) }; diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index c1a768f02..d90e9b102 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -282,6 +282,14 @@ namespace Slang return m_mangledName; } + String EntryPoint::getEntryPointNameOverride(Index index) + { + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + + return m_nameOverride; + } + void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo)); @@ -1211,7 +1219,11 @@ namespace Slang // Set the entry point name options.entryPointName = getText(entryPoint->getName()); - + auto entryPointNameOverride = program->getEntryPointNameOverride(entryPointIndex); + if (entryPointNameOverride.getLength() != 0) + { + options.entryPointName = entryPointNameOverride; + } if (compilerType == PassThroughMode::Dxc) { // We will enable the flag to generate proper code for 16 - bit types diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 52c03ffdb..930210a96 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -319,6 +319,9 @@ namespace Slang /// Get the mangled name of one of the entry points linked into this component type. virtual String getEntryPointMangledName(Index index) = 0; + /// Get the name override of one of the entry points linked into this component type. + virtual String getEntryPointNameOverride(Index index) = 0; + /// Get the number of global shader parameters linked into this component type. virtual Index getShaderParamCount() = 0; @@ -511,6 +514,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE; RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE; String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; Index getShaderParamCount() SLANG_OVERRIDE; ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; @@ -560,6 +564,7 @@ namespace Slang // List<EntryPoint*> m_entryPoints; List<String> m_entryPointMangledNames; + List<String> m_entryPointNameOverrides; List<ShaderParamInfo> m_shaderParams; List<SpecializationParam> m_specializationParams; List<ComponentType*> m_requirements; @@ -598,6 +603,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); } String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); } @@ -638,6 +644,7 @@ namespace Slang RefPtr<IRModule> m_irModule; List<String> m_entryPointMangledNames; + List<String> m_entryPointNameOverrides; // Any tagged union types that were referenced by the specialization arguments. List<TaggedUnionType*> m_taggedUnionTypes; @@ -721,6 +728,15 @@ namespace Slang return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint) + SLANG_OVERRIDE + { + RefPtr<EntryPoint> newEntryPoint = create(getLinkage(), m_funcDeclRef, m_profile); + newEntryPoint->m_nameOverride = newName; + *outEntryPoint = newEntryPoint.detach(); + return SLANG_OK; + } + /// Create an entry point that refers to the given function. static RefPtr<EntryPoint> create( Linkage* linkage, @@ -791,6 +807,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; } String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; Index getShaderParamCount() SLANG_OVERRIDE { return 0; } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return ShaderParamInfo(); } @@ -831,6 +848,9 @@ namespace Slang /// The mangled name of the entry point function String m_mangledName; + /// The name of this entry point in the compiled code. + String m_nameOverride; + SpecializationParams m_genericSpecializationParams; SpecializationParams m_existentialSpecializationParams; @@ -940,6 +960,7 @@ namespace Slang return nullptr; } String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } + String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; } Index getShaderParamCount() SLANG_OVERRIDE { return 0; } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE @@ -1107,6 +1128,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE { return 0; } RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } String getEntryPointMangledName(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); } Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 77998fba1..385b536dc 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -809,7 +809,8 @@ static void maybeCopyLayoutInformationToParameters( IRFunc* specializeIRForEntryPoint( IRSpecContext* context, - String const& mangledName) + String const& mangledName, + String const& nameOverride) { // We start by looking up the IR symbol that // matches the mangled name given to the @@ -848,6 +849,20 @@ IRFunc* specializeIRForEntryPoint( // auto clonedVal = cloneGlobalValue(context, originalVal); + if (nameOverride.getLength()) + { + if (auto entryPointDecor = clonedVal->findDecoration<IREntryPointDecoration>()) + { + IRInst* operands[] = { + entryPointDecor->getProfileInst(), + context->builder->getStringValue(nameOverride.getUnownedSlice()), + entryPointDecor->getModuleName()}; + context->builder->addDecoration( + clonedVal, IROp::kIROp_EntryPointDecoration, operands, 3); + entryPointDecor->removeAndDeallocate(); + } + } + // In the case where the user is requesting a specialization // of a generic entry point, we have a bit of a problem. // @@ -1425,7 +1440,8 @@ LinkedIR linkIR( for (auto entryPointIndex : entryPointIndices) { auto entryPointMangledName = program->getEntryPointMangledName(entryPointIndex); - irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName)); + auto nameOverride = program->getEntryPointNameOverride(entryPointIndex); + irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName, nameOverride)); } // Layout information for global shader parameters is also required, diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5fbc5b4bf..1333e3660 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3325,6 +3325,7 @@ CompositeComponentType::CompositeComponentType( { m_entryPoints.add(child->getEntryPoint(cc)); m_entryPointMangledNames.add(child->getEntryPointMangledName(cc)); + m_entryPointNameOverrides.add(child->getEntryPointNameOverride(cc)); } auto childShaderParamCount = child->getShaderParamCount(); @@ -3376,6 +3377,11 @@ String CompositeComponentType::getEntryPointMangledName(Index index) return m_entryPointMangledNames[index]; } +String CompositeComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + Index CompositeComponentType::getShaderParamCount() { return m_shaderParams.getCount(); @@ -3783,6 +3789,7 @@ SpecializedComponentType::SpecializedComponentType( struct EntryPointMangledNameCollector : ComponentTypeVisitor { List<String>* mangledEntryPointNames; + List<String>* entryPointNameOverrides; void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE { @@ -3791,6 +3798,7 @@ SpecializedComponentType::SpecializedComponentType( funcDeclRef = specializationInfo->specializedFuncDeclRef; (*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef)); + (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0)); } void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE @@ -3815,6 +3823,7 @@ SpecializedComponentType::SpecializedComponentType( // EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder()); collector.mangledEntryPointNames = &m_entryPointMangledNames; + collector.entryPointNameOverrides = &m_entryPointNameOverrides; collector.visitSpecialized(this); } @@ -3840,6 +3849,11 @@ String SpecializedComponentType::getEntryPointMangledName(Index index) return m_entryPointMangledNames[index]; } +String SpecializedComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) { auto childCount = composite->getChildComponentCount(); |
