diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2022-06-08 10:23:01 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-08 10:23:01 -0400 |
| commit | 8e6e884eca5b33218a8cb2714266fb6ed4548d75 (patch) | |
| tree | a9c8aee79a71450a64e6660da7266b6a45da0264 /tools | |
| parent | 01d0154ae90f5c587321d39b8fd8f82e2764f360 (diff) | |
Actual global support (#2262)
* #include an absolute path didn't work - because paths were taken to always be relative.
* Use TerminatedUnownedStringSlice for literals in output C++.
* Remove Escape/Unescape functions used in slang-token-reader.cpp
Add target type of 'host-cpp' etc to map to the target types.
* Fix some corner cases around string encoding.
* Added unit test for string escaping.
Fixed some assorted escaping bugs.
* Updated test output.
* Added decode test.
* Stop using hex output, to get around 'greedy' aspect. Use octal instead.
* Added HostHostCallable
Small changes to use ArtifactDesc/Info instead of large switches.
* Fix C++ emit to handle arbitrary function export.
* Add options handling for callable without an output being specified.
* Can compile with COM interface. Added example using com interface.
* Use the IR Ptr type instead of hack in C++ emit for interfaces.
* Fix issue with outputting the COM call when ptr is used.
* Fix crash issue on compilation failure.
* Add support for __global.
* Added `ActualGlobalRate`
Added special handling around globals and COM interfaces.
Tested out in cpu-com-example.
* Fix typo in NodeBase.
* Support for accessing globals by name working.
* Check that actual global initialization is working.
* Refactor the com replacement such that it doesn't need a cache or do anything special with GlobalVar.
* Remove context.
Only create replacement if needed.
* Split out COM host-callable into a unit-test.
* host-callable com testing on C++and llvm.
* Comment around the COM ptr replacement.
* Disable com test on vs 32 bit.
Fix C++ prelude
* Disable 32 bit targets testing com host-callable.
* Use JSON parsing to locate VS version.
* Need platform detection in C++prelude.
* Fix com host callable test for LLVM.
* Work around for not being able to include "targetConditionals.h"
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-com-host-callable.cpp | 335 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-com-host-callable.slang | 50 | ||||
| -rw-r--r-- | tools/unit-test/slang-unit-test.h | 2 |
3 files changed, 386 insertions, 1 deletions
diff --git a/tools/slang-unit-test/unit-test-com-host-callable.cpp b/tools/slang-unit-test/unit-test-com-host-callable.cpp new file mode 100644 index 000000000..57209fbde --- /dev/null +++ b/tools/slang-unit-test/unit-test-com-host-callable.cpp @@ -0,0 +1,335 @@ +// unit-test-com-host-callable.cpp + +#include "../../source/core/slang-byte-encode-util.h" + +#include <stdio.h> +#include <stdlib.h> + +#include "tools/unit-test/slang-unit-test.h" + +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../../slang-com-ptr.h" + +#include "../../source/core/slang-list.h" + +namespace { // anonymous + +// Slang namespace is used for elements support code (like core) which we use here +// for ComPtr<> and TestToolUtil +using namespace Slang; + +// For the moment we have to explicitly write the Slang COM interface in C++ code. It *MUST* match +// the interface in the slang source +// As it stands all interfaces need to derive from ISlangUnknown (or IUnknown). +class IDoThings : public ISlangUnknown +{ +public: + virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) = 0; + virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) = 0; +}; + +class ICountGood : public ISlangUnknown +{ +public: + virtual SLANG_NO_THROW int SLANG_MCALL nextCount() = 0; +}; + +static int _calcHash(const char* in) +{ + int hash = 0; + for (; *in; ++in) + { + // A very poor hash function + hash = hash * 13 + *in; + } + return hash; +} + +class DoThings : public IDoThings +{ +public: + // We don't need queryInterface for this impl, or ref counting + virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE { return SLANG_E_NOT_IMPLEMENTED; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; } + + // IDoThings + virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) SLANG_OVERRIDE { return a + b + 1; } + virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) SLANG_OVERRIDE { return (int)_calcHash(in); } +}; + +class CountGood : public ICountGood +{ +public: + // We don't need queryInterface for this impl, or ref counting + virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE { return SLANG_E_NOT_IMPLEMENTED; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; } + + // ICountGood + virtual SLANG_NO_THROW int SLANG_MCALL nextCount() SLANG_OVERRIDE { return m_count++; } + + int m_count = 0; +}; + +struct ComTestContext +{ + ComTestContext(UnitTestContext* context): + m_unitTestContext(context) + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + + m_defaultCppCompiler = slangSession->getDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP); + + m_hostHostCallableCompiler = slangSession->getDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE); + m_shaderHostCallableCompiler = slangSession->getDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE); + + } + + SlangResult runTests() + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + + // TODO(JS): + // Care is needed around this in normal testing. `slang-llvm` is whatever was asked for for when premake was built + // when the target is specified. Otherwise it is the `default` which is typically 64 bit during development. + // + // On CI we should be okay, because it should download the correct `slang-llvm` for the build (as it packages up with it). + // But for normal development, that can easily not be the case (for example changing to 32 bit build in VS is a problem). + // + // Make sure to run + // + // ``` + // premake --arch=x86 --deps=true + // ``` + // + // for the actual target/arch(!) + + const bool hasLlvm = SLANG_SUCCEEDED(slangSession->checkPassThroughSupport(SLANG_PASS_THROUGH_LLVM)); + + SlangPassThrough cppCompiler = SLANG_PASS_THROUGH_NONE; + + { + const SlangPassThrough cppCompilers[] = + { + SLANG_PASS_THROUGH_VISUAL_STUDIO, + SLANG_PASS_THROUGH_GCC, + SLANG_PASS_THROUGH_CLANG, + }; + // Do we have a C++ compiler + for (const auto compiler : cppCompilers) + { + if (SLANG_SUCCEEDED(slangSession->checkPassThroughSupport(compiler))) + { + cppCompiler = compiler; + break; + } + } + } + + // If we have an *actual* C++ compile rtest on that first + if (cppCompiler != SLANG_PASS_THROUGH_NONE) + { + slangSession->setDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP, cppCompiler); + + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE, cppCompiler); + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE, cppCompiler); + + SLANG_RETURN_ON_FAIL(_runTest()); + } + + // Reset the compiler that's used for host-callable + _reset(); + + // If we have Llvm it is the default host callable compiler + if (hasLlvm) + { + // Should run via slang-llvm + SLANG_RETURN_ON_FAIL(_runTest()); + } + + return SLANG_OK; + } + + void _reset() + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + slangSession->setDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP, m_defaultCppCompiler); + + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE, m_shaderHostCallableCompiler); + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE, m_hostHostCallableCompiler); + } + + ~ComTestContext() + { + _reset(); + } + + SlangResult _runTest(); + + UnitTestContext* m_unitTestContext; + + SlangPassThrough m_defaultCppCompiler; + SlangPassThrough m_hostHostCallableCompiler; + SlangPassThrough m_shaderHostCallableCompiler; +}; + +SlangResult ComTestContext::_runTest() +{ + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + + // Create a compile request + Slang::ComPtr<slang::ICompileRequest> request; + SLANG_RETURN_ON_FAIL(slangSession->createCompileRequest(request.writeRef())); + + // We want to compile to 'HOST_CALLABLE' here such that we can execute the Slang code. + // + // Note that it is possible to use HOST_HOST_CALLABLE, but this currently only works with 'regular' C++ compilers + // not with `slang-llvm`. + const int targetIndex = request->addCodeGenTarget(SLANG_SHADER_HOST_CALLABLE); + + // Set the target flag to indicate that we want to compile all into a library. + request->setTargetFlags(targetIndex, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM); + + request->setOptimizationLevel(SLANG_OPTIMIZATION_LEVEL_NONE); + request->setDebugInfoLevel(SLANG_DEBUG_INFO_LEVEL_STANDARD); + + // Add the translation unit + const int translationUnitIndex = request->addTranslationUnit(SLANG_SOURCE_LANGUAGE_SLANG, nullptr); + + // Set the source file for the translation unit + request->addTranslationUnitSourceFile(translationUnitIndex, "tools/slang-unit-test/unit-test-com-host-callable.slang"); + + const SlangResult compileRes = request->compile(); + + // Even if there were no errors that forced compilation to fail, the + // compiler may have produced "diagnostic" output such as warnings. + // We will go ahead and print that output here. + // + if (auto diagnostics = request->getDiagnosticOutput()) + { + printf("%s", diagnostics); + } + + // Get the 'shared library' (note that this doesn't necessarily have to be implemented as a shared library + // it's just an interface to executable code). + ComPtr<ISlangSharedLibrary> sharedLibrary; + SLANG_RETURN_ON_FAIL(request->getTargetHostCallable(0, sharedLibrary.writeRef())); + + { + typedef const char* (*Func)(const char*); + Func func = (Func)sharedLibrary->findFuncByName("getString"); + + if (!func) + { + return SLANG_FAIL; + } + + String text = "Hello World!"; + String returnedText = func(text.getBuffer()); + + SLANG_CHECK(text == returnedText); + } + { + typedef int (*Func)(const char* text, IDoThings* doThings); + + Func func = (Func)sharedLibrary->findFuncByName("calcHash"); + + if (!func) + { + return SLANG_FAIL; + } + + DoThings doThings; + + String text("Hello"); + + const int hash = func(text.getBuffer(), &doThings); + + SLANG_CHECK(hash == _calcHash(text.getBuffer())); + } + + // Check accessing a global + { + typedef void (*SetFunc)(int v); + typedef int (*GetFunc)(); + + const auto setGlobal = (SetFunc)sharedLibrary->findFuncByName("setGlobal"); + const auto getGlobal = (GetFunc)sharedLibrary->findFuncByName("getGlobal"); + + if (setGlobal == nullptr || getGlobal == nullptr) + { + return SLANG_FAIL; + } + + // In the slang source it is set a default value + SLANG_CHECK(getGlobal() == 10); + + for (Index i = 0; i < 10; ++i) + { + setGlobal(int(i)); + SLANG_CHECK(getGlobal() == i); + } + } + + // Check using a global interface + { + + typedef void (*SetCounterFunc)(ICountGood* counter); + typedef int (*NextCountFunc)(); + + const auto setCounter = (SetCounterFunc)sharedLibrary->findFuncByName("setCounter"); + const auto nextCount = (NextCountFunc)sharedLibrary->findFuncByName("nextCount"); + + if (setCounter == nullptr || nextCount == nullptr) + { + return SLANG_FAIL; + } + + CountGood counter; + + ICountGood* counterIntf = &counter; + + setCounter(counterIntf); + + auto counterPtr = (ICountGood**)sharedLibrary->findSymbolAddressByName("globalCounter"); + SLANG_CHECK(counterPtr); + if (!counterPtr) + { + return SLANG_FAIL; + } + + for (Index i = 0; i < 10; ++i) + { + SLANG_CHECK(*counterPtr == &counter); + + const auto v = nextCount(); + SLANG_CHECK(v == i); + } + } + + return SLANG_OK; +} + +} // anonymous + +SLANG_UNIT_TEST(comHostCallable) +{ +#if SLANG_PTR_IS_32 && !SLANG_MICROSOFT_FAMILY + // TODO(JS): + // We can't currently run this test reliably on targets other than windows + // Visual Studio DownstreamCompiler has support for 32 bit builds + // Other targets generally build for the native environment which is almost always 64 bit, + // and it requires other features to build/test 32 bit binaries on such systems. + // + // So we disable for any 32 bit non MS target for now + return; +#endif + + ComTestContext context(unitTestContext); + + const auto result = context.runTests(); + + SLANG_CHECK(SLANG_SUCCEEDED(result)); +} diff --git a/tools/slang-unit-test/unit-test-com-host-callable.slang b/tools/slang-unit-test/unit-test-com-host-callable.slang new file mode 100644 index 000000000..b591904b3 --- /dev/null +++ b/tools/slang-unit-test/unit-test-com-host-callable.slang @@ -0,0 +1,50 @@ +// shader.slang + +// Example of using 'NativeString' + +public __extern_cpp NativeString getString(NativeString in) +{ + return in; +} + +public __extern_cpp __global int intGlobal = 10; + +public __extern_cpp void setGlobal(int v) +{ + intGlobal = v; +} + +public __extern_cpp int getGlobal() +{ + return intGlobal; +} + +[COM] +interface IDoThings +{ + int doThing(int a, int b); + int calcHash(NativeString in); +} + +[COM] +interface ICountGood +{ + int nextCount(); +} + +public __extern_cpp __global ICountGood globalCounter; + +public __extern_cpp void setCounter(ICountGood counter) +{ + globalCounter = counter; +} + +public __extern_cpp int nextCount() +{ + return globalCounter.nextCount(); +} + +public __extern_cpp int calcHash(NativeString text, IDoThings doThings) +{ + return doThings.calcHash(text); +} diff --git a/tools/unit-test/slang-unit-test.h b/tools/unit-test/slang-unit-test.h index 035f67e33..3c8bcaea6 100644 --- a/tools/unit-test/slang-unit-test.h +++ b/tools/unit-test/slang-unit-test.h @@ -67,7 +67,7 @@ typedef IUnitTestModule* (*UnitTestGetModuleFunc)(); void _##name##_impl(UnitTestContext* unitTestContext); \ void name(UnitTestContext* unitTestContext)\ {\ - try { _##name##_impl(unitTestContext); } catch (AbortTestException){} \ + try { _##name##_impl(unitTestContext); } catch (AbortTestException&){} \ }\ UnitTestRegisterHelper _##name##RegisterHelper(#name, name); \ void _##name##_impl(UnitTestContext* unitTestContext) |
