diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-17 21:57:58 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-17 21:57:58 -0800 |
| commit | 6e24244832d9032f0993cb088af625238096b723 (patch) | |
| tree | ab31fe5d6446df4f96154e6a8272c43593be6aac | |
| parent | 6f57e47a9e1675b011f023277b47cfc768d30da8 (diff) | |
Support specializing generic entrypoints in `findAndCheckEntryPoint`. (#5890)
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 32 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-generic-entrypoint.cpp | 67 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-glsl-compile.cpp | 69 |
3 files changed, 150 insertions, 18 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 50382f9c1..c9e190469 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -229,29 +229,25 @@ bool isPrimaryDecl(CallableDecl* decl) return (!decl->primaryDecl) || (decl == decl->primaryDecl); } -FuncDecl* findFunctionDeclByName(Module* translationUnit, Name* name, DiagnosticSink* sink) +DeclRef<FuncDecl> findFunctionDeclByName(Module* translationUnit, Name* name, DiagnosticSink* sink) { - FuncDecl* entryPointFuncDecl = nullptr; + DeclRef<FuncDecl> entryPointFuncDeclRef; auto expr = translationUnit->findDeclFromString(getText(name), sink); if (auto declRefExpr = as<DeclRefExpr>(expr)) { - auto declRef = declRefExpr->declRef; - entryPointFuncDecl = declRef.as<FuncDecl>().getDecl(); + entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); - if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) - entryPointFuncDecl = nullptr; + if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit) + entryPointFuncDeclRef = DeclRef<FuncDecl>(); } - if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) - entryPointFuncDecl = nullptr; - - if (!entryPointFuncDecl) + if (!entryPointFuncDeclRef) { auto translationUnitSyntax = translationUnit->getModuleDecl(); sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name); } - return entryPointFuncDecl; + return entryPointFuncDeclRef; } // Is a entry pointer parmaeter of `type` always a uniform parameter? @@ -409,8 +405,8 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) // make up the entry point. // Name* name = linkage->getNamePool()->getName(stringLit->value); - FuncDecl* patchConstantFuncDecl = findFunctionDeclByName(module, name, sink); - if (!patchConstantFuncDecl) + DeclRef<FuncDecl> patchConstantFuncDeclRef = findFunctionDeclByName(module, name, sink); + if (!patchConstantFuncDeclRef) { sink->diagnose( expr, @@ -420,7 +416,7 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) return; } - attr->patchConstantFuncDecl = patchConstantFuncDecl; + attr->patchConstantFuncDecl = patchConstantFuncDeclRef.getDecl(); } } else if (stage == Stage::Compute) @@ -648,11 +644,11 @@ RefPtr<EntryPoint> findAndValidateEntryPoint(FrontEndEntryPointRequest* entryPoi auto sink = compileRequest->getSink(); auto entryPointName = entryPointReq->getName(); - FuncDecl* entryPointFuncDecl = + DeclRef<FuncDecl> entryPointFuncDeclRef = findFunctionDeclByName(translationUnit->getModule(), entryPointName, sink); // Did we find a function declaration in our search? - if (!entryPointFuncDecl) + if (!entryPointFuncDeclRef) { return nullptr; } @@ -673,14 +669,14 @@ RefPtr<EntryPoint> findAndValidateEntryPoint(FrontEndEntryPointRequest* entryPoi entryPointProfile, linkage->m_optionSet, linkage->targets, - entryPointFuncDecl, + entryPointFuncDeclRef.getDecl(), sink); // TODO: Should we attach a `[shader(...)]` attribute to an // entry point that didn't have one, so that we can have // a more uniform representation in the AST? RefPtr<EntryPoint> entryPoint = - EntryPoint::create(linkage, makeDeclRef(entryPointFuncDecl), entryPointProfile); + EntryPoint::create(linkage, entryPointFuncDeclRef, entryPointProfile); // Now that we've *found* the entry point, it is time to validate // that it actually meets the constraints for the chosen stage/profile. diff --git a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp new file mode 100644 index 000000000..741fe35bc --- /dev/null +++ b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp @@ -0,0 +1,67 @@ +// unit-test-generic-entrypoint.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +// Test the compilation API for compiling a specialized generic entrypoint. + +SLANG_UNIT_TEST(genericEntryPointCompile) +{ + const char* userSourceBody = R"( + interface I { int getValue(); } + struct X : I { int getValue() { return 100; } } + float4 vertMain<T:I>(uniform T o) { + return float4(o.getValue(), 0, 0, 1); + } + )"; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_GLSL; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain<X>", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue")) != -1); +} diff --git a/tools/slang-unit-test/unit-test-glsl-compile.cpp b/tools/slang-unit-test/unit-test-glsl-compile.cpp new file mode 100644 index 000000000..5a49da91b --- /dev/null +++ b/tools/slang-unit-test/unit-test-glsl-compile.cpp @@ -0,0 +1,69 @@ +// unit-test-glsl-compile.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +// Test the compilation API for cross-compiling glsl source to SPIRV. + +SLANG_UNIT_TEST(glslCompile) +{ + const char* userSourceBody = R"( + #version 450 core + layout(location = 0) in vec2 aPosition; + layout(location = 1) in vec4 aColor; + layout(location = 0) out vec4 vColor; + void main() { + vColor = aColor; + gl_Position = vec4(aPosition, 0, 1); + } + )"; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "main", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK(code != nullptr); +} |
