summaryrefslogtreecommitdiffstats
path: root/source/slang/slang.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang.cpp')
-rw-r--r--source/slang/slang.cpp55
1 files changed, 55 insertions, 0 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 5f22f8a23..4ae3f4654 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -7,6 +7,7 @@
#include "slang-parameter-binding.h"
#include "slang-lower-to-ir.h"
+#include "slang-mangle.h"
#include "slang-parser.h"
#include "slang-preprocessor.h"
#include "slang-reflection.h"
@@ -1136,6 +1137,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
{
auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq);
targetProgram->getOrCreateLayout(getSink());
+ targetProgram->getOrCreateIRModuleForLayout(getSink());
}
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -2033,6 +2035,7 @@ CompositeComponentType::CompositeComponentType(
for(Index cc = 0; cc < childEntryPointCount; ++cc)
{
m_entryPoints.add(child->getEntryPoint(cc));
+ m_entryPointMangledNames.add(child->getEntryPointMangledName(cc));
}
auto childShaderParamCount = child->getShaderParamCount();
@@ -2079,6 +2082,11 @@ RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index)
return m_entryPoints[index];
}
+String CompositeComponentType::getEntryPointMangledName(Index index)
+{
+ return m_entryPointMangledNames[index];
+}
+
Index CompositeComponentType::getShaderParamCount()
{
return m_shaderParams.getCount();
@@ -2198,6 +2206,48 @@ SpecializedComponentType::SpecializedComponentType(
m_taggedUnionTypes.add(taggedUnionType);
}
+
+ // Because we are specializing shader code, the mangled entry
+ // point names for this component type may be different than
+ // for the base component type (e.g., the mangled name for `f<int>`
+ // is different than that that of the generic `f` function
+ // itself).
+ //
+ // We will compute the mangled names of all the entry points and
+ // store them here, so that we don't have to do it on the fly.
+ // Because the `ComponentType` structure is hierarchical, we
+ // need to use a recursive visitor to compute the names,
+ // and we will define that visitor locally:
+ //
+ struct EntryPointMangledNameCollector : ComponentTypeVisitor
+ {
+ List<String>* mangledEntryPointNames;
+
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ auto funcDeclRef = entryPoint->getFuncDeclRef();
+ if(specializationInfo)
+ funcDeclRef = specializationInfo->specializedFuncDeclRef;
+
+ (*mangledEntryPointNames).add(getMangledName(funcDeclRef));
+ }
+
+ void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
+ {}
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ { visitChildren(composite, specializationInfo); }
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ { visitChildren(specialized); }
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ { visitChildren(legacy, specializationInfo); }
+ };
+
+ // With the visitor defined, we apply it to ourself to compute
+ // and collect the mangled entry point names.
+ //
+ EntryPointMangledNameCollector collector;
+ collector.mangledEntryPointNames = &m_entryPointMangledNames;
+ collector.visitSpecialized(this);
}
void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
@@ -2221,6 +2271,11 @@ RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index)
return m_base->getRequirement(index);
}
+String SpecializedComponentType::getEntryPointMangledName(Index index)
+{
+ return m_entryPointMangledNames[index];
+}
+
//
// LegacyProgram
//