summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2022-06-08 10:23:01 -0400
committerGitHub <noreply@github.com>2022-06-08 10:23:01 -0400
commit8e6e884eca5b33218a8cb2714266fb6ed4548d75 (patch)
treea9c8aee79a71450a64e6660da7266b6a45da0264 /tools
parent01d0154ae90f5c587321d39b8fd8f82e2764f360 (diff)
Actual global support (#2262)
* #include an absolute path didn't work - because paths were taken to always be relative. * Use TerminatedUnownedStringSlice for literals in output C++. * Remove Escape/Unescape functions used in slang-token-reader.cpp Add target type of 'host-cpp' etc to map to the target types. * Fix some corner cases around string encoding. * Added unit test for string escaping. Fixed some assorted escaping bugs. * Updated test output. * Added decode test. * Stop using hex output, to get around 'greedy' aspect. Use octal instead. * Added HostHostCallable Small changes to use ArtifactDesc/Info instead of large switches. * Fix C++ emit to handle arbitrary function export. * Add options handling for callable without an output being specified. * Can compile with COM interface. Added example using com interface. * Use the IR Ptr type instead of hack in C++ emit for interfaces. * Fix issue with outputting the COM call when ptr is used. * Fix crash issue on compilation failure. * Add support for __global. * Added `ActualGlobalRate` Added special handling around globals and COM interfaces. Tested out in cpu-com-example. * Fix typo in NodeBase. * Support for accessing globals by name working. * Check that actual global initialization is working. * Refactor the com replacement such that it doesn't need a cache or do anything special with GlobalVar. * Remove context. Only create replacement if needed. * Split out COM host-callable into a unit-test. * host-callable com testing on C++and llvm. * Comment around the COM ptr replacement. * Disable com test on vs 32 bit. Fix C++ prelude * Disable 32 bit targets testing com host-callable. * Use JSON parsing to locate VS version. * Need platform detection in C++prelude. * Fix com host callable test for LLVM. * Work around for not being able to include "targetConditionals.h"
Diffstat (limited to 'tools')
-rw-r--r--tools/slang-unit-test/unit-test-com-host-callable.cpp335
-rw-r--r--tools/slang-unit-test/unit-test-com-host-callable.slang50
-rw-r--r--tools/unit-test/slang-unit-test.h2
3 files changed, 386 insertions, 1 deletions
diff --git a/tools/slang-unit-test/unit-test-com-host-callable.cpp b/tools/slang-unit-test/unit-test-com-host-callable.cpp
new file mode 100644
index 000000000..57209fbde
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-com-host-callable.cpp
@@ -0,0 +1,335 @@
+// unit-test-com-host-callable.cpp
+
+#include "../../source/core/slang-byte-encode-util.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tools/unit-test/slang-unit-test.h"
+
+#include "../../slang.h"
+#include "../../slang-com-helper.h"
+#include "../../slang-com-ptr.h"
+
+#include "../../source/core/slang-list.h"
+
+namespace { // anonymous
+
+// Slang namespace is used for elements support code (like core) which we use here
+// for ComPtr<> and TestToolUtil
+using namespace Slang;
+
+// For the moment we have to explicitly write the Slang COM interface in C++ code. It *MUST* match
+// the interface in the slang source
+// As it stands all interfaces need to derive from ISlangUnknown (or IUnknown).
+class IDoThings : public ISlangUnknown
+{
+public:
+ virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) = 0;
+ virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) = 0;
+};
+
+class ICountGood : public ISlangUnknown
+{
+public:
+ virtual SLANG_NO_THROW int SLANG_MCALL nextCount() = 0;
+};
+
+static int _calcHash(const char* in)
+{
+ int hash = 0;
+ for (; *in; ++in)
+ {
+ // A very poor hash function
+ hash = hash * 13 + *in;
+ }
+ return hash;
+}
+
+class DoThings : public IDoThings
+{
+public:
+ // We don't need queryInterface for this impl, or ref counting
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE { return SLANG_E_NOT_IMPLEMENTED; }
+ virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return 1; }
+ virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; }
+
+ // IDoThings
+ virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) SLANG_OVERRIDE { return a + b + 1; }
+ virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) SLANG_OVERRIDE { return (int)_calcHash(in); }
+};
+
+class CountGood : public ICountGood
+{
+public:
+ // We don't need queryInterface for this impl, or ref counting
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE { return SLANG_E_NOT_IMPLEMENTED; }
+ virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return 1; }
+ virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; }
+
+ // ICountGood
+ virtual SLANG_NO_THROW int SLANG_MCALL nextCount() SLANG_OVERRIDE { return m_count++; }
+
+ int m_count = 0;
+};
+
+struct ComTestContext
+{
+ ComTestContext(UnitTestContext* context):
+ m_unitTestContext(context)
+ {
+ slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession;
+
+ m_defaultCppCompiler = slangSession->getDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP);
+
+ m_hostHostCallableCompiler = slangSession->getDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE);
+ m_shaderHostCallableCompiler = slangSession->getDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE);
+
+ }
+
+ SlangResult runTests()
+ {
+ slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession;
+
+ // TODO(JS):
+ // Care is needed around this in normal testing. `slang-llvm` is whatever was asked for for when premake was built
+ // when the target is specified. Otherwise it is the `default` which is typically 64 bit during development.
+ //
+ // On CI we should be okay, because it should download the correct `slang-llvm` for the build (as it packages up with it).
+ // But for normal development, that can easily not be the case (for example changing to 32 bit build in VS is a problem).
+ //
+ // Make sure to run
+ //
+ // ```
+ // premake --arch=x86 --deps=true
+ // ```
+ //
+ // for the actual target/arch(!)
+
+ const bool hasLlvm = SLANG_SUCCEEDED(slangSession->checkPassThroughSupport(SLANG_PASS_THROUGH_LLVM));
+
+ SlangPassThrough cppCompiler = SLANG_PASS_THROUGH_NONE;
+
+ {
+ const SlangPassThrough cppCompilers[] =
+ {
+ SLANG_PASS_THROUGH_VISUAL_STUDIO,
+ SLANG_PASS_THROUGH_GCC,
+ SLANG_PASS_THROUGH_CLANG,
+ };
+ // Do we have a C++ compiler
+ for (const auto compiler : cppCompilers)
+ {
+ if (SLANG_SUCCEEDED(slangSession->checkPassThroughSupport(compiler)))
+ {
+ cppCompiler = compiler;
+ break;
+ }
+ }
+ }
+
+ // If we have an *actual* C++ compile rtest on that first
+ if (cppCompiler != SLANG_PASS_THROUGH_NONE)
+ {
+ slangSession->setDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP, cppCompiler);
+
+ slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE, cppCompiler);
+ slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE, cppCompiler);
+
+ SLANG_RETURN_ON_FAIL(_runTest());
+ }
+
+ // Reset the compiler that's used for host-callable
+ _reset();
+
+ // If we have Llvm it is the default host callable compiler
+ if (hasLlvm)
+ {
+ // Should run via slang-llvm
+ SLANG_RETURN_ON_FAIL(_runTest());
+ }
+
+ return SLANG_OK;
+ }
+
+ void _reset()
+ {
+ slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession;
+ slangSession->setDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP, m_defaultCppCompiler);
+
+ slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE, m_shaderHostCallableCompiler);
+ slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE, m_hostHostCallableCompiler);
+ }
+
+ ~ComTestContext()
+ {
+ _reset();
+ }
+
+ SlangResult _runTest();
+
+ UnitTestContext* m_unitTestContext;
+
+ SlangPassThrough m_defaultCppCompiler;
+ SlangPassThrough m_hostHostCallableCompiler;
+ SlangPassThrough m_shaderHostCallableCompiler;
+};
+
+SlangResult ComTestContext::_runTest()
+{
+ slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession;
+
+ // Create a compile request
+ Slang::ComPtr<slang::ICompileRequest> request;
+ SLANG_RETURN_ON_FAIL(slangSession->createCompileRequest(request.writeRef()));
+
+ // We want to compile to 'HOST_CALLABLE' here such that we can execute the Slang code.
+ //
+ // Note that it is possible to use HOST_HOST_CALLABLE, but this currently only works with 'regular' C++ compilers
+ // not with `slang-llvm`.
+ const int targetIndex = request->addCodeGenTarget(SLANG_SHADER_HOST_CALLABLE);
+
+ // Set the target flag to indicate that we want to compile all into a library.
+ request->setTargetFlags(targetIndex, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM);
+
+ request->setOptimizationLevel(SLANG_OPTIMIZATION_LEVEL_NONE);
+ request->setDebugInfoLevel(SLANG_DEBUG_INFO_LEVEL_STANDARD);
+
+ // Add the translation unit
+ const int translationUnitIndex = request->addTranslationUnit(SLANG_SOURCE_LANGUAGE_SLANG, nullptr);
+
+ // Set the source file for the translation unit
+ request->addTranslationUnitSourceFile(translationUnitIndex, "tools/slang-unit-test/unit-test-com-host-callable.slang");
+
+ const SlangResult compileRes = request->compile();
+
+ // Even if there were no errors that forced compilation to fail, the
+ // compiler may have produced "diagnostic" output such as warnings.
+ // We will go ahead and print that output here.
+ //
+ if (auto diagnostics = request->getDiagnosticOutput())
+ {
+ printf("%s", diagnostics);
+ }
+
+ // Get the 'shared library' (note that this doesn't necessarily have to be implemented as a shared library
+ // it's just an interface to executable code).
+ ComPtr<ISlangSharedLibrary> sharedLibrary;
+ SLANG_RETURN_ON_FAIL(request->getTargetHostCallable(0, sharedLibrary.writeRef()));
+
+ {
+ typedef const char* (*Func)(const char*);
+ Func func = (Func)sharedLibrary->findFuncByName("getString");
+
+ if (!func)
+ {
+ return SLANG_FAIL;
+ }
+
+ String text = "Hello World!";
+ String returnedText = func(text.getBuffer());
+
+ SLANG_CHECK(text == returnedText);
+ }
+ {
+ typedef int (*Func)(const char* text, IDoThings* doThings);
+
+ Func func = (Func)sharedLibrary->findFuncByName("calcHash");
+
+ if (!func)
+ {
+ return SLANG_FAIL;
+ }
+
+ DoThings doThings;
+
+ String text("Hello");
+
+ const int hash = func(text.getBuffer(), &doThings);
+
+ SLANG_CHECK(hash == _calcHash(text.getBuffer()));
+ }
+
+ // Check accessing a global
+ {
+ typedef void (*SetFunc)(int v);
+ typedef int (*GetFunc)();
+
+ const auto setGlobal = (SetFunc)sharedLibrary->findFuncByName("setGlobal");
+ const auto getGlobal = (GetFunc)sharedLibrary->findFuncByName("getGlobal");
+
+ if (setGlobal == nullptr || getGlobal == nullptr)
+ {
+ return SLANG_FAIL;
+ }
+
+ // In the slang source it is set a default value
+ SLANG_CHECK(getGlobal() == 10);
+
+ for (Index i = 0; i < 10; ++i)
+ {
+ setGlobal(int(i));
+ SLANG_CHECK(getGlobal() == i);
+ }
+ }
+
+ // Check using a global interface
+ {
+
+ typedef void (*SetCounterFunc)(ICountGood* counter);
+ typedef int (*NextCountFunc)();
+
+ const auto setCounter = (SetCounterFunc)sharedLibrary->findFuncByName("setCounter");
+ const auto nextCount = (NextCountFunc)sharedLibrary->findFuncByName("nextCount");
+
+ if (setCounter == nullptr || nextCount == nullptr)
+ {
+ return SLANG_FAIL;
+ }
+
+ CountGood counter;
+
+ ICountGood* counterIntf = &counter;
+
+ setCounter(counterIntf);
+
+ auto counterPtr = (ICountGood**)sharedLibrary->findSymbolAddressByName("globalCounter");
+ SLANG_CHECK(counterPtr);
+ if (!counterPtr)
+ {
+ return SLANG_FAIL;
+ }
+
+ for (Index i = 0; i < 10; ++i)
+ {
+ SLANG_CHECK(*counterPtr == &counter);
+
+ const auto v = nextCount();
+ SLANG_CHECK(v == i);
+ }
+ }
+
+ return SLANG_OK;
+}
+
+} // anonymous
+
+SLANG_UNIT_TEST(comHostCallable)
+{
+#if SLANG_PTR_IS_32 && !SLANG_MICROSOFT_FAMILY
+ // TODO(JS):
+ // We can't currently run this test reliably on targets other than windows
+ // Visual Studio DownstreamCompiler has support for 32 bit builds
+ // Other targets generally build for the native environment which is almost always 64 bit,
+ // and it requires other features to build/test 32 bit binaries on such systems.
+ //
+ // So we disable for any 32 bit non MS target for now
+ return;
+#endif
+
+ ComTestContext context(unitTestContext);
+
+ const auto result = context.runTests();
+
+ SLANG_CHECK(SLANG_SUCCEEDED(result));
+}
diff --git a/tools/slang-unit-test/unit-test-com-host-callable.slang b/tools/slang-unit-test/unit-test-com-host-callable.slang
new file mode 100644
index 000000000..b591904b3
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-com-host-callable.slang
@@ -0,0 +1,50 @@
+// shader.slang
+
+// Example of using 'NativeString'
+
+public __extern_cpp NativeString getString(NativeString in)
+{
+ return in;
+}
+
+public __extern_cpp __global int intGlobal = 10;
+
+public __extern_cpp void setGlobal(int v)
+{
+ intGlobal = v;
+}
+
+public __extern_cpp int getGlobal()
+{
+ return intGlobal;
+}
+
+[COM]
+interface IDoThings
+{
+ int doThing(int a, int b);
+ int calcHash(NativeString in);
+}
+
+[COM]
+interface ICountGood
+{
+ int nextCount();
+}
+
+public __extern_cpp __global ICountGood globalCounter;
+
+public __extern_cpp void setCounter(ICountGood counter)
+{
+ globalCounter = counter;
+}
+
+public __extern_cpp int nextCount()
+{
+ return globalCounter.nextCount();
+}
+
+public __extern_cpp int calcHash(NativeString text, IDoThings doThings)
+{
+ return doThings.calcHash(text);
+}
diff --git a/tools/unit-test/slang-unit-test.h b/tools/unit-test/slang-unit-test.h
index 035f67e33..3c8bcaea6 100644
--- a/tools/unit-test/slang-unit-test.h
+++ b/tools/unit-test/slang-unit-test.h
@@ -67,7 +67,7 @@ typedef IUnitTestModule* (*UnitTestGetModuleFunc)();
void _##name##_impl(UnitTestContext* unitTestContext); \
void name(UnitTestContext* unitTestContext)\
{\
- try { _##name##_impl(unitTestContext); } catch (AbortTestException){} \
+ try { _##name##_impl(unitTestContext); } catch (AbortTestException&){} \
}\
UnitTestRegisterHelper _##name##RegisterHelper(#name, name); \
void _##name##_impl(UnitTestContext* unitTestContext)