diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-compiler.cpp | 14 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 22 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 14 |
4 files changed, 67 insertions, 3 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index c1a768f02..d90e9b102 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -282,6 +282,14 @@ namespace Slang return m_mangledName; } + String EntryPoint::getEntryPointNameOverride(Index index) + { + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + + return m_nameOverride; + } + void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo)); @@ -1211,7 +1219,11 @@ namespace Slang // Set the entry point name options.entryPointName = getText(entryPoint->getName()); - + auto entryPointNameOverride = program->getEntryPointNameOverride(entryPointIndex); + if (entryPointNameOverride.getLength() != 0) + { + options.entryPointName = entryPointNameOverride; + } if (compilerType == PassThroughMode::Dxc) { // We will enable the flag to generate proper code for 16 - bit types diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 52c03ffdb..930210a96 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -319,6 +319,9 @@ namespace Slang /// Get the mangled name of one of the entry points linked into this component type. virtual String getEntryPointMangledName(Index index) = 0; + /// Get the name override of one of the entry points linked into this component type. + virtual String getEntryPointNameOverride(Index index) = 0; + /// Get the number of global shader parameters linked into this component type. virtual Index getShaderParamCount() = 0; @@ -511,6 +514,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE; RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE; String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; Index getShaderParamCount() SLANG_OVERRIDE; ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; @@ -560,6 +564,7 @@ namespace Slang // List<EntryPoint*> m_entryPoints; List<String> m_entryPointMangledNames; + List<String> m_entryPointNameOverrides; List<ShaderParamInfo> m_shaderParams; List<SpecializationParam> m_specializationParams; List<ComponentType*> m_requirements; @@ -598,6 +603,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); } String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); } @@ -638,6 +644,7 @@ namespace Slang RefPtr<IRModule> m_irModule; List<String> m_entryPointMangledNames; + List<String> m_entryPointNameOverrides; // Any tagged union types that were referenced by the specialization arguments. List<TaggedUnionType*> m_taggedUnionTypes; @@ -721,6 +728,15 @@ namespace Slang return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); } + SLANG_NO_THROW SlangResult SLANG_MCALL getRenamedEntryPoint(const char* newName, IEntryPoint** outEntryPoint) + SLANG_OVERRIDE + { + RefPtr<EntryPoint> newEntryPoint = create(getLinkage(), m_funcDeclRef, m_profile); + newEntryPoint->m_nameOverride = newName; + *outEntryPoint = newEntryPoint.detach(); + return SLANG_OK; + } + /// Create an entry point that refers to the given function. static RefPtr<EntryPoint> create( Linkage* linkage, @@ -791,6 +807,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; } String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; Index getShaderParamCount() SLANG_OVERRIDE { return 0; } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return ShaderParamInfo(); } @@ -831,6 +848,9 @@ namespace Slang /// The mangled name of the entry point function String m_mangledName; + /// The name of this entry point in the compiled code. + String m_nameOverride; + SpecializationParams m_genericSpecializationParams; SpecializationParams m_existentialSpecializationParams; @@ -940,6 +960,7 @@ namespace Slang return nullptr; } String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } + String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; } Index getShaderParamCount() SLANG_OVERRIDE { return 0; } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE @@ -1107,6 +1128,7 @@ namespace Slang Index getEntryPointCount() SLANG_OVERRIDE { return 0; } RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } String getEntryPointMangledName(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); } Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 77998fba1..385b536dc 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -809,7 +809,8 @@ static void maybeCopyLayoutInformationToParameters( IRFunc* specializeIRForEntryPoint( IRSpecContext* context, - String const& mangledName) + String const& mangledName, + String const& nameOverride) { // We start by looking up the IR symbol that // matches the mangled name given to the @@ -848,6 +849,20 @@ IRFunc* specializeIRForEntryPoint( // auto clonedVal = cloneGlobalValue(context, originalVal); + if (nameOverride.getLength()) + { + if (auto entryPointDecor = clonedVal->findDecoration<IREntryPointDecoration>()) + { + IRInst* operands[] = { + entryPointDecor->getProfileInst(), + context->builder->getStringValue(nameOverride.getUnownedSlice()), + entryPointDecor->getModuleName()}; + context->builder->addDecoration( + clonedVal, IROp::kIROp_EntryPointDecoration, operands, 3); + entryPointDecor->removeAndDeallocate(); + } + } + // In the case where the user is requesting a specialization // of a generic entry point, we have a bit of a problem. // @@ -1425,7 +1440,8 @@ LinkedIR linkIR( for (auto entryPointIndex : entryPointIndices) { auto entryPointMangledName = program->getEntryPointMangledName(entryPointIndex); - irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName)); + auto nameOverride = program->getEntryPointNameOverride(entryPointIndex); + irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName, nameOverride)); } // Layout information for global shader parameters is also required, diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5fbc5b4bf..1333e3660 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3325,6 +3325,7 @@ CompositeComponentType::CompositeComponentType( { m_entryPoints.add(child->getEntryPoint(cc)); m_entryPointMangledNames.add(child->getEntryPointMangledName(cc)); + m_entryPointNameOverrides.add(child->getEntryPointNameOverride(cc)); } auto childShaderParamCount = child->getShaderParamCount(); @@ -3376,6 +3377,11 @@ String CompositeComponentType::getEntryPointMangledName(Index index) return m_entryPointMangledNames[index]; } +String CompositeComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + Index CompositeComponentType::getShaderParamCount() { return m_shaderParams.getCount(); @@ -3783,6 +3789,7 @@ SpecializedComponentType::SpecializedComponentType( struct EntryPointMangledNameCollector : ComponentTypeVisitor { List<String>* mangledEntryPointNames; + List<String>* entryPointNameOverrides; void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE { @@ -3791,6 +3798,7 @@ SpecializedComponentType::SpecializedComponentType( funcDeclRef = specializationInfo->specializedFuncDeclRef; (*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef)); + (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0)); } void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE @@ -3815,6 +3823,7 @@ SpecializedComponentType::SpecializedComponentType( // EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder()); collector.mangledEntryPointNames = &m_entryPointMangledNames; + collector.entryPointNameOverrides = &m_entryPointNameOverrides; collector.visitSpecialized(this); } @@ -3840,6 +3849,11 @@ String SpecializedComponentType::getEntryPointMangledName(Index index) return m_entryPointMangledNames[index]; } +String SpecializedComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) { auto childCount = composite->getChildComponentCount(); |
