diff options
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5f22f8a23..4ae3f4654 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -7,6 +7,7 @@ #include "slang-parameter-binding.h" #include "slang-lower-to-ir.h" +#include "slang-mangle.h" #include "slang-parser.h" #include "slang-preprocessor.h" #include "slang-reflection.h" @@ -1136,6 +1137,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() { auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); + targetProgram->getOrCreateIRModuleForLayout(getSink()); } if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -2033,6 +2035,7 @@ CompositeComponentType::CompositeComponentType( for(Index cc = 0; cc < childEntryPointCount; ++cc) { m_entryPoints.add(child->getEntryPoint(cc)); + m_entryPointMangledNames.add(child->getEntryPointMangledName(cc)); } auto childShaderParamCount = child->getShaderParamCount(); @@ -2079,6 +2082,11 @@ RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) return m_entryPoints[index]; } +String CompositeComponentType::getEntryPointMangledName(Index index) +{ + return m_entryPointMangledNames[index]; +} + Index CompositeComponentType::getShaderParamCount() { return m_shaderParams.getCount(); @@ -2198,6 +2206,48 @@ SpecializedComponentType::SpecializedComponentType( m_taggedUnionTypes.add(taggedUnionType); } + + // Because we are specializing shader code, the mangled entry + // point names for this component type may be different than + // for the base component type (e.g., the mangled name for `f<int>` + // is different than that that of the generic `f` function + // itself). + // + // We will compute the mangled names of all the entry points and + // store them here, so that we don't have to do it on the fly. + // Because the `ComponentType` structure is hierarchical, we + // need to use a recursive visitor to compute the names, + // and we will define that visitor locally: + // + struct EntryPointMangledNameCollector : ComponentTypeVisitor + { + List<String>* mangledEntryPointNames; + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + auto funcDeclRef = entryPoint->getFuncDeclRef(); + if(specializationInfo) + funcDeclRef = specializationInfo->specializedFuncDeclRef; + + (*mangledEntryPointNames).add(getMangledName(funcDeclRef)); + } + + void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + {} + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { visitChildren(composite, specializationInfo); } + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { visitChildren(specialized); } + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { visitChildren(legacy, specializationInfo); } + }; + + // With the visitor defined, we apply it to ourself to compute + // and collect the mangled entry point names. + // + EntryPointMangledNameCollector collector; + collector.mangledEntryPointNames = &m_entryPointMangledNames; + collector.visitSpecialized(this); } void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) @@ -2221,6 +2271,11 @@ RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) return m_base->getRequirement(index); } +String SpecializedComponentType::getEntryPointMangledName(Index index) +{ + return m_entryPointMangledNames[index]; +} + // // LegacyProgram // |
