summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-10-09 00:39:38 -0700
committerGitHub <noreply@github.com>2024-10-09 00:39:38 -0700
commitbea1394ad35680940a0b69b9c67efc43764cc194 (patch)
tree903eb3befc070a257a85f6522dbd9d5a48950dcb
parent132111ab0582493e09898222b275d512719a92b0 (diff)
Fix bug related to findAndCheckEntrypoint. (#5241)
-rw-r--r--source/slang/slang-ir-link.cpp50
-rw-r--r--tools/slang-unit-test/unit-test-find-check-entrypoint.cpp10
2 files changed, 37 insertions, 23 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 01b1c20de..b44c0cd5e 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -908,8 +908,8 @@ static void maybeCopyLayoutInformationToParameters(
IRFunc* specializeIRForEntryPoint(
IRSpecContext* context,
- String const& mangledName,
- String const& nameOverride)
+ EntryPoint* entryPoint,
+ UnownedStringSlice nameOverride)
{
// We start by looking up the IR symbol that
// matches the mangled name given to the
@@ -921,6 +921,7 @@ IRFunc* specializeIRForEntryPoint(
// not the same as the mangled name of the decl.
//
RefPtr<IRSpecSymbol> sym;
+ auto mangledName = entryPoint->getEntryPointMangledName(0);
if (!context->getSymbols().tryGetValue(mangledName, sym))
{
String hashedName = getHashedName(mangledName.getUnownedSlice());
@@ -948,20 +949,6 @@ IRFunc* specializeIRForEntryPoint(
//
auto clonedVal = cloneGlobalValue(context, originalVal);
- if (nameOverride.getLength())
- {
- if (auto entryPointDecor = clonedVal->findDecoration<IREntryPointDecoration>())
- {
- IRInst* operands[] = {
- entryPointDecor->getProfileInst(),
- context->builder->getStringValue(nameOverride.getUnownedSlice()),
- entryPointDecor->getModuleName()};
- context->builder->addDecoration(
- clonedVal, IROp::kIROp_EntryPointDecoration, operands, 3);
- entryPointDecor->removeAndDeallocate();
- }
- }
-
// In the case where the user is requesting a specialization
// of a generic entry point, we have a bit of a problem.
//
@@ -1023,6 +1010,34 @@ IRFunc* specializeIRForEntryPoint(
context->builder->addKeepAliveDecoration(clonedFunc);
}
+ // If an IREntryPointDecoration already exist in the function,
+ // check if we need to update its name with nameOverride.
+ // If the decoration doesn't exist, create it with the desired name.
+ if (auto entryPointDecor = clonedFunc->findDecoration<IREntryPointDecoration>())
+ {
+ if (nameOverride.getLength())
+ {
+ IRInst* operands[] = {
+ entryPointDecor->getProfileInst(),
+ context->builder->getStringValue(nameOverride),
+ entryPointDecor->getModuleName() };
+ context->builder->addDecoration(
+ clonedFunc, IROp::kIROp_EntryPointDecoration, operands, 3);
+ entryPointDecor->removeAndDeallocate();
+ }
+ }
+ else
+ {
+ if (!nameOverride.getLength())
+ nameOverride = getUnownedStringSliceText(entryPoint->getName());
+ IRInst* operands[] = {
+ context->builder->getIntValue(context->builder->getIntType(), entryPoint->getProfile().raw),
+ context->builder->getStringValue(nameOverride),
+ context->builder->getStringValue(UnownedStringSlice(entryPoint->getModule()->getName())) };
+ context->builder->addDecoration(
+ clonedFunc, IROp::kIROp_EntryPointDecoration, operands, 3);
+ }
+
// We will also go on and attach layout information
// to the function parameters, so that we have it
// available directly on the parameters, rather
@@ -1803,7 +1818,8 @@ LinkedIR linkIR(
{
auto entryPointMangledName = program->getEntryPointMangledName(entryPointIndex);
auto nameOverride = program->getEntryPointNameOverride(entryPointIndex);
- irEntryPoints.add(specializeIRForEntryPoint(context, entryPointMangledName, nameOverride));
+ auto entryPoint = program->getEntryPoint(entryPointIndex).get();
+ irEntryPoints.add(specializeIRForEntryPoint(context, entryPoint, nameOverride.getUnownedSlice()));
}
// Layout information for global shader parameters is also required,
diff --git a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
index 423b83b65..122f26ddd 100644
--- a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
+++ b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
@@ -19,7 +19,7 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
{
// Source for a module that contains an undecorated entrypoint.
const char* userSourceBody = R"(
- float4 fragMain(float4 pos:SV_Position) : SV_Position
+ float4 fragMain(float4 pos:SV_Position) : SV_Target
{
return pos;
}
@@ -30,8 +30,8 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
- targetDesc.format = SLANG_HLSL;
- targetDesc.profile = globalSession->findProfile("sm_5_0");
+ targetDesc.format = SLANG_SPIRV;
+ targetDesc.profile = globalSession->findProfile("spirv_1_5");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
@@ -58,8 +58,6 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
ComPtr<slang::IBlob> code;
linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(code != nullptr);
-
- auto codeSrc = UnownedStringSlice((const char*)code->getBufferPointer());
- SLANG_CHECK(codeSrc.indexOf(toSlice("fragMain")) != -1);
+ SLANG_CHECK(code->getBufferSize() != 0);
}