diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2022-06-08 10:23:01 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-08 10:23:01 -0400 |
| commit | 8e6e884eca5b33218a8cb2714266fb6ed4548d75 (patch) | |
| tree | a9c8aee79a71450a64e6660da7266b6a45da0264 | |
| parent | 01d0154ae90f5c587321d39b8fd8f82e2764f360 (diff) | |
Actual global support (#2262)
* #include an absolute path didn't work - because paths were taken to always be relative.
* Use TerminatedUnownedStringSlice for literals in output C++.
* Remove Escape/Unescape functions used in slang-token-reader.cpp
Add target type of 'host-cpp' etc to map to the target types.
* Fix some corner cases around string encoding.
* Added unit test for string escaping.
Fixed some assorted escaping bugs.
* Updated test output.
* Added decode test.
* Stop using hex output, to get around 'greedy' aspect. Use octal instead.
* Added HostHostCallable
Small changes to use ArtifactDesc/Info instead of large switches.
* Fix C++ emit to handle arbitrary function export.
* Add options handling for callable without an output being specified.
* Can compile with COM interface. Added example using com interface.
* Use the IR Ptr type instead of hack in C++ emit for interfaces.
* Fix issue with outputting the COM call when ptr is used.
* Fix crash issue on compilation failure.
* Add support for __global.
* Added `ActualGlobalRate`
Added special handling around globals and COM interfaces.
Tested out in cpu-com-example.
* Fix typo in NodeBase.
* Support for accessing globals by name working.
* Check that actual global initialization is working.
* Refactor the com replacement such that it doesn't need a cache or do anything special with GlobalVar.
* Remove context.
Only create replacement if needed.
* Split out COM host-callable into a unit-test.
* host-callable com testing on C++and llvm.
* Comment around the COM ptr replacement.
* Disable com test on vs 32 bit.
Fix C++ prelude
* Disable 32 bit targets testing com host-callable.
* Use JSON parsing to locate VS version.
* Need platform detection in C++prelude.
* Fix com host callable test for LLVM.
* Work around for not being able to include "targetConditionals.h"
27 files changed, 855 insertions, 118 deletions
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj index 87cd8e9ec..2d1e50ab1 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj @@ -273,6 +273,7 @@ <ItemGroup>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-byte-encode.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-chunked-list.cpp" />
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-com-host-callable.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-command-line-args.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-compression.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp" />
@@ -292,6 +293,9 @@ <ClCompile Include="..\..\..\tools\unit-test\slang-unit-test.cpp" />
</ItemGroup>
<ItemGroup>
+ <None Include="..\..\..\tools\slang-unit-test\unit-test-com-host-callable.slang" />
+ </ItemGroup>
+ <ItemGroup>
<ProjectReference Include="..\lz4\lz4.vcxproj">
<Project>{E1EC8075-823E-46E5-BC38-C124CCCDF878}</Project>
</ProjectReference>
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters index 4a4e7bce9..5cdfc14a3 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters @@ -20,6 +20,9 @@ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-chunked-list.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-com-host-callable.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-command-line-args.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -72,4 +75,9 @@ <Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
+ <ItemGroup>
+ <None Include="..\..\..\tools\slang-unit-test\unit-test-com-host-callable.slang">
+ <Filter>Source Files</Filter>
+ </None>
+ </ItemGroup>
</Project>
\ No newline at end of file diff --git a/examples/cpu-com-example/main.cpp b/examples/cpu-com-example/main.cpp index 2e9eeed79..dedbf69f5 100644 --- a/examples/cpu-com-example/main.cpp +++ b/examples/cpu-com-example/main.cpp @@ -7,7 +7,6 @@ #include <slang-com-ptr.h> #include <slang-com-helper.h> - // This includes a useful small function for setting up the prelude (described more further below). #include "../../source/core/slang-test-tool-util.h" @@ -21,8 +20,9 @@ using namespace Slang; class IDoThings : public ISlangUnknown { public: - virtual int SLANG_MCALL doThing(int a, int b) = 0; - virtual int SLANG_MCALL calcHash(const char* in) = 0; + virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) = 0; + virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) = 0; + virtual SLANG_NO_THROW void SLANG_MCALL printMessage(const char* in) = 0; }; static int _calcHash(const char* in) @@ -36,7 +36,7 @@ static int _calcHash(const char* in) return hash; } -class DoThings :public IDoThings +class DoThings : public IDoThings { public: // We don't need queryInterface for this impl, or ref counting @@ -45,17 +45,21 @@ public: virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; } // IDoThings - virtual int SLANG_MCALL doThing(int a, int b) SLANG_OVERRIDE { return a + b + 1; } - virtual int SLANG_MCALL calcHash(const char* in) SLANG_OVERRIDE { return (int)_calcHash(in); } + virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) SLANG_OVERRIDE { return a + b + 1; } + virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) SLANG_OVERRIDE { return (int)_calcHash(in); } + virtual SLANG_NO_THROW void SLANG_MCALL printMessage(const char* in) SLANG_OVERRIDE { printf("%s\n", in); } }; static SlangResult _innerMain(int argc, char** argv) { + // NOTE! This example only works if `slang-llvm` or a C++ compiler that Slang supports is available. + // Create the session ComPtr<slang::IGlobalSession> slangSession; slangSession.attach(spCreateSession(NULL)); // Set up the prelude + // NOTE: This isn't strictly necessary, as preludes are embedded in the binary. TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], slangSession); // Create a compile request @@ -93,6 +97,19 @@ static SlangResult _innerMain(int argc, char** argv) ComPtr<ISlangSharedLibrary> sharedLibrary; SLANG_RETURN_ON_FAIL(request->getTargetHostCallable(0, sharedLibrary.writeRef())); + DoThings doThings; + + { + auto doThingsPtr = (IDoThings**)sharedLibrary->findSymbolAddressByName("globalDoThings"); + if (!doThingsPtr) + { + return SLANG_FAIL; + } + // Set the global interface + *doThingsPtr = &doThings; + } + + // Test a free function { typedef const char* (*Func)(const char*); Func func = (Func)sharedLibrary->findFuncByName("getString"); @@ -107,25 +124,34 @@ static SlangResult _innerMain(int argc, char** argv) SLANG_ASSERT(text == returnedText); } - { - typedef int (*Func)(const char* text, IDoThings* doThings); + // Test hash + { + typedef int (*Func)(const char* text); Func func = (Func)sharedLibrary->findFuncByName("calcHash"); - if (!func) { return SLANG_FAIL; } - DoThings doThings; - String text("Hello"); - - const int hash = func(text.getBuffer(), &doThings); - + const int hash = func(text.getBuffer()); SLANG_ASSERT(hash == _calcHash(text.getBuffer())); } - + + // Test printing + { + typedef void (*Func)(const char* text); + + Func func = (Func)sharedLibrary->findFuncByName("printMessage"); + + if (!func) + { + return SLANG_FAIL; + } + func("Hello World!"); + } + return SLANG_OK; } diff --git a/examples/cpu-com-example/shader.slang b/examples/cpu-com-example/shader.slang index b0fe259be..44b1f5b81 100644 --- a/examples/cpu-com-example/shader.slang +++ b/examples/cpu-com-example/shader.slang @@ -1,6 +1,8 @@ // shader.slang -// Example of using 'NativeString' +// Example using 'NativeString' and COM interface + +public __global __extern_cpp IDoThings globalDoThings; public __extern_cpp NativeString getString(NativeString in) { @@ -12,10 +14,15 @@ interface IDoThings { int doThing(int a, int b); int calcHash(NativeString in); + void printMessage(NativeString nativeString); } -public __extern_cpp int calcHash(NativeString text, IDoThings doThings) +public __extern_cpp int calcHash(NativeString text) { - return doThings.calcHash(text); + return globalDoThings.calcHash(text); } +public __extern_cpp void printMessage(NativeString text) +{ + return globalDoThings.printMessage(text); +} diff --git a/prelude/slang-cpp-prelude.h b/prelude/slang-cpp-prelude.h index db83222a6..ff6bb8f6f 100644 --- a/prelude/slang-cpp-prelude.h +++ b/prelude/slang-cpp-prelude.h @@ -99,7 +99,109 @@ Any compilers not detected by the above logic are now now explicitly zeroed out. # endif #endif /* SLANG_COMPILER */ +/* +The following section attempts to detect the target platform being compiled for. + +If an application defines `SLANG_PLATFORM` before including this header, +they take responsibility for setting any compiler-dependent macros +used later in the file. + +Most applications should not need to touch this section. +*/ +#ifndef SLANG_PLATFORM +# define SLANG_PLATFORM +/** +Operating system defines, see http://sourceforge.net/p/predef/wiki/OperatingSystems/ +*/ +# if defined(WINAPI_FAMILY) && WINAPI_FAMILY == WINAPI_PARTITION_APP +# define SLANG_WINRT 1 /* Windows Runtime, either on Windows RT or Windows 8 */ +# elif defined(XBOXONE) +# define SLANG_XBOXONE 1 +# elif defined(_WIN64) /* note: XBOXONE implies _WIN64 */ +# define SLANG_WIN64 1 +# elif defined(_M_PPC) +# define SLANG_X360 1 +# elif defined(_WIN32) /* note: _M_PPC implies _WIN32 */ +# define SLANG_WIN32 1 +# elif defined(__ANDROID__) +# define SLANG_ANDROID 1 +# elif defined(__linux__) || defined(__CYGWIN__) /* note: __ANDROID__ implies __linux__ */ +# define SLANG_LINUX 1 +# elif defined(__APPLE__) && !defined(SLANG_LLVM) +# include "TargetConditionals.h" +# if TARGET_OS_MAC +# define SLANG_OSX 1 +# else +# define SLANG_IOS 1 +# endif +# elif defined(__APPLE__) +// On `slang-llvm` we can't inclue "TargetConditionals.h" in general, so for now assume its OSX. +# define SLANG_OSX 1 +# elif defined(__CELLOS_LV2__) +# define SLANG_PS3 1 +# elif defined(__ORBIS__) +# define SLANG_PS4 1 +# elif defined(__SNC__) && defined(__arm__) +# define SLANG_PSP2 1 +# elif defined(__ghs__) +# define SLANG_WIIU 1 +# else +# error "unknown target platform" +# endif + + +/* +Any platforms not detected by the above logic are now now explicitly zeroed out. +*/ +# ifndef SLANG_WINRT +# define SLANG_WINRT 0 +# endif +# ifndef SLANG_XBOXONE +# define SLANG_XBOXONE 0 +# endif +# ifndef SLANG_WIN64 +# define SLANG_WIN64 0 +# endif +# ifndef SLANG_X360 +# define SLANG_X360 0 +# endif +# ifndef SLANG_WIN32 +# define SLANG_WIN32 0 +# endif +# ifndef SLANG_ANDROID +# define SLANG_ANDROID 0 +# endif +# ifndef SLANG_LINUX +# define SLANG_LINUX 0 +# endif +# ifndef SLANG_IOS +# define SLANG_IOS 0 +# endif +# ifndef SLANG_OSX +# define SLANG_OSX 0 +# endif +# ifndef SLANG_PS3 +# define SLANG_PS3 0 +# endif +# ifndef SLANG_PS4 +# define SLANG_PS4 0 +# endif +# ifndef SLANG_PSP2 +# define SLANG_PSP2 0 +# endif +# ifndef SLANG_WIIU +# define SLANG_WIIU 0 +# endif +#endif /* SLANG_PLATFORM */ + + +/* Shorthands for "families" of compilers/platforms */ #define SLANG_GCC_FAMILY (SLANG_CLANG || SLANG_SNC || SLANG_GHS || SLANG_GCC) +#define SLANG_WINDOWS_FAMILY (SLANG_WINRT || SLANG_WIN32 || SLANG_WIN64) +#define SLANG_MICROSOFT_FAMILY (SLANG_XBOXONE || SLANG_X360 || SLANG_WINDOWS_FAMILY) +#define SLANG_LINUX_FAMILY (SLANG_LINUX || SLANG_ANDROID) +#define SLANG_APPLE_FAMILY (SLANG_IOS || SLANG_OSX) /* equivalent to #if __APPLE__ */ +#define SLANG_UNIX_FAMILY (SLANG_LINUX_FAMILY || SLANG_APPLE_FAMILY) /* shortcut for unix/posix platforms */ // GCC Specific #if SLANG_GCC_FAMILY @@ -147,6 +249,9 @@ convention for interface methods. # define SLANG_MCALL SLANG_STDCALL #endif + + + struct SlangUUID { uint32_t data1; diff --git a/prelude/slang-llvm.h b/prelude/slang-llvm.h index 22966ead0..08d6a74dd 100644 --- a/prelude/slang-llvm.h +++ b/prelude/slang-llvm.h @@ -1,6 +1,11 @@ #ifndef SLANG_LLVM_H #define SLANG_LLVM_H +// TODO(JS): +// Disable exception declspecs, as not supported on LLVM without some extra options. +// We could enable with `-fms-extensions` +#define SLANG_DISABLE_EXCEPTIONS 1 + #ifndef SLANG_PRELUDE_ASSERT # ifdef DEBUG extern "C" void assertFailure(const char* msg); @@ -1003,10 +1003,8 @@ extern "C" @param name The name of the function @return The function pointer related to the name or nullptr if not found */ - inline SlangFuncPtr SLANG_MCALL findFuncByName(char const* name) - { - return reinterpret_cast<SlangFuncPtr>(findSymbolAddressByName(name)); - } + SLANG_FORCE_INLINE SlangFuncPtr findFuncByName(char const* name) { return (SlangFuncPtr)findSymbolAddressByName(name); } + /** Get a symbol by name. If the library is unloaded will only return nullptr. @param name The name of the symbol @return The pointer related to the name or nullptr if not found @@ -1062,7 +1060,7 @@ extern "C" cache source contents internally. It is also used for #pragma once functionality. A *requirement* is for any implementation is that two paths can only return the same uniqueIdentity if the - contents of the two files are *identical*h. If an implementation breaks this constraint it can produce incorrect compilation. + contents of the two files are *identical*. If an implementation breaks this constraint it can produce incorrect compilation. If an implementation cannot *strictly* identify *the same* files, this will only have an effect on #pragma once behavior. The string for the uniqueIdentity is held zero terminated in the ISlangBlob of outUniqueIdentity. diff --git a/source/compiler-core/windows/slang-win-visual-studio-util.cpp b/source/compiler-core/windows/slang-win-visual-studio-util.cpp index e1a0f6109..9b175f308 100644 --- a/source/compiler-core/windows/slang-win-visual-studio-util.cpp +++ b/source/compiler-core/windows/slang-win-visual-studio-util.cpp @@ -4,6 +4,9 @@ #include "../../core/slang-process-util.h" #include "../../core/slang-string-util.h" +#include "../slang-json-parser.h" +#include "../slang-json-value.h" + #include "../slang-visual-studio-compiler-util.h" #ifdef _WIN32 @@ -82,6 +85,7 @@ VersionInfo _makeVersionInfo(const char* name, int high, int dot = 0) return info; } +// https://en.wikipedia.org/wiki/Microsoft_Visual_Studio static const VersionInfo s_versionInfos[] = { _makeVersionInfo("VS 2005", 8), @@ -92,6 +96,7 @@ static const VersionInfo s_versionInfos[] = _makeVersionInfo("VS 2015", 14), _makeVersionInfo("VS 2017", 15), _makeVersionInfo("VS 2019", 16), + _makeVersionInfo("VS 2022", 17), }; // When trying to figure out how this stuff works by running regedit - care is needed, @@ -135,7 +140,7 @@ static int _getRegistryKeyIndex(Version version) /* static */WinVisualStudioUtil::Version WinVisualStudioUtil::getCompiledVersion() { // Get the version of visual studio used to compile this source - const uint32_t version = _MSC_VER; + uint32_t version = _MSC_VER; switch (version) { @@ -156,27 +161,51 @@ static int _getRegistryKeyIndex(Version version) case 1916: { return _makeVersion(15); - } - case 1920: - { - return _makeVersion(16); - } - default: - { - int lastKnownVersion = 1920; - if (version > lastKnownVersion) - { - // Its an unknown newer version - return Version::Future; - } - break; - } + } + default: break; + } + + // Seems like versions go in runs of 10 at this point + // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=msvc-170 + + if (version >= 1920 && version < 1930) + { + return _makeVersion(16); + } + else if (version >= 1930 && version < 1940) + { + // We are going to assume it's a run of t0 + return _makeVersion(17); + } + else if (version >= 1940) + { + // Its an unknown newer version + return Version::Future; } // Unknown version return Version::Unknown; } +static SlangResult _parseJson(const String& contents, DiagnosticSink* sink, JSONContainer* container, JSONValue& outRoot) +{ + auto sourceManager = sink->getSourceManager(); + + SourceFile* sourceFile = sourceManager->createSourceFileWithString(PathInfo::makeUnknown(), contents); + SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr, SourceLoc()); + + JSONLexer lexer; + lexer.init(sourceView, sink); + + JSONBuilder builder(container); + + JSONParser parser; + SLANG_RETURN_ON_FAIL(parser.parse(&lexer, sourceView, &builder, sink)); + + outRoot = builder.getRootValue(); + return SLANG_OK; +} + static SlangResult _find(int versionIndex, WinVisualStudioUtil::VersionPath& outPath) { const auto& versionInfo = s_versionInfos[versionIndex]; @@ -202,22 +231,83 @@ static SlangResult _find(int versionIndex, WinVisualStudioUtil::VersionPath& out cmd.setExecutableLocation(ExecutableLocation(vswherePath)); + const auto desc = WinVisualStudioUtil::getDesc(version); + StringBuilder versionName; WinVisualStudioUtil::append(version, versionName); - String args[] = { "-version", versionName, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", "-property", "installationPath" }; + // Using -? we can find out vswhere options. + + // Previous args - works but returns multiple versions, without listing what version is associated with which path + // or the order. + //String args[] = { "-version", versionName, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", "-json", "-property", "installationPath", "-property", "installationVersion" }; + + // Use JSON parsing, we can verify the versions for a path, otherwise multiple versions are returned + // not just the version specified. The ordering isn't defined (and -sort doesn't appear to work) + String args[] = { "-version", versionName, "-format", "json", "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64"}; + cmd.addArgs(args, SLANG_COUNT_OF(args)); + SourceManager sourceManager; + sourceManager.initialize(nullptr, nullptr); + DiagnosticSink sink(&sourceManager, nullptr); + + RefPtr<JSONContainer> container = new JSONContainer(&sourceManager); + ExecuteResult exeRes; if (SLANG_SUCCEEDED(ProcessUtil::execute(cmd, exeRes))) { - // We need to chopoff CR/LF if there is one - List<UnownedStringSlice> lines; - StringUtil::calcLines(exeRes.standardOutput.getUnownedSlice(), lines); + JSONValue jsonRoot; + SLANG_RETURN_ON_FAIL(_parseJson(exeRes.standardOutput, &sink, container, jsonRoot)); + + // Search through the array... + if (jsonRoot.getKind() != JSONValue::Kind::Array) + { + return SLANG_FAIL; + } + + auto arr = container->getArray(jsonRoot); + + const auto pathKey = container->getKey(UnownedStringSlice::fromLiteral("installationPath")); + const auto versionKey = container->getKey(UnownedStringSlice::fromLiteral("installationVersion")); - if (lines.getCount()) + for (auto elem : arr) { - outPath.vcvarsPath = lines[0]; + // Get the path and the name + if (elem.getKind() != JSONValue::Kind::Object) + { + continue; + } + + auto pathJsonValue = container->findObjectValue(elem, pathKey); + auto versionJsonValue = container->findObjectValue(elem, versionKey); + + if (!pathJsonValue.isValid() || !versionJsonValue.isValid()) + { + continue; + } + + auto pathString = container->getString(pathJsonValue); + auto versionString = container->getString(versionJsonValue).trim(); + + // If the versionString matches + List<UnownedStringSlice> versionSlices; + StringUtil::split(versionString, '.', versionSlices); + + if (versionSlices.getCount() <= 0) + { + continue; + } + + Int versionValue; + SLANG_RETURN_ON_FAIL(StringUtil::parseInt(versionSlices[0], versionValue)); + + if (versionValue != desc.majorVersion) + { + continue; + } + + outPath.vcvarsPath = pathString; outPath.vcvarsPath.append("\\VC\\Auxiliary\\Build\\"); return SLANG_OK; } diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h index eef7d6908..470e5f6d9 100644 --- a/source/core/slang-dictionary.h +++ b/source/core/slang-dictionary.h @@ -381,6 +381,13 @@ namespace Slang else SLANG_ASSERT_FAILURE("Inconsistent find result returned. This is a bug in Dictionary implementation."); } + void Set(const TKey& key, const TValue& value) + { + if (auto ptr = TryGetValueOrAdd(key, value)) + { + *ptr = value; + } + } template<typename KeyType> bool ContainsKey(const KeyType& key) const diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index c74edb938..3126aab71 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -33,6 +33,9 @@ class NodeBase /// correctly constructed (through ASTBuilder) NodeBase derived class. /// The actual type is set when constructed on the ASTBuilder. ASTNodeType astNodeType = ASTNodeType(-1); + + // Handy when debugging, shouldn't be checked in though! + // virtual ~NodeBase() {} }; // Casting of NodeBase diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 4522b7148..012c74377 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -31,6 +31,10 @@ class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)}; class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)}; class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; +// An 'ActualGlobal' is a global that is output as a normal global in CPU code. +// Globals in HLSL/Slang are constant state passed into kernel execution +class ActualGlobalModifier : public Modifier { SLANG_AST_CLASS(ActualGlobalModifier)}; + /// A modifier that indicates an `InheritanceDecl` should be ignored during name lookup (and related checks). class IgnoreForLookupModifier : public Modifier { SLANG_AST_CLASS(IgnoreForLookupModifier) }; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a2311b186..bb762c1c6 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -267,6 +267,9 @@ namespace Slang /// Is `decl` a global shader parameter declaration? bool isGlobalShaderParameter(VarDeclBase* decl) { + // If it's an *actual* global it is not a global shader parameter + if (decl->hasModifier<ActualGlobalModifier>()) { return false; } + // A global shader parameter must be declared at global or namespace // scope, so that it has a single definition across the module. // diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index ff1f660a1..148b0205b 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2887,7 +2887,7 @@ namespace Slang void addTransition(CodeGenTarget source, CodeGenTarget target, PassThroughMode compiler) { SLANG_ASSERT(source != target); - m_map.Add(Pair{ source, target }, compiler); + m_map.Set(Pair{ source, target }, compiler); } bool hasTransition(CodeGenTarget source, CodeGenTarget target) const { diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index ddc8b24ed..c23135c70 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -2547,7 +2547,50 @@ void CPPSourceEmitter::emitPreModuleImpl() } } -/* virtual */void CPPSourceEmitter::emitFuncDecorationsImpl(IRFunc* func) + +void CPPSourceEmitter::emitGlobalInstImpl(IRInst* inst) +{ + if (as<IRGlobalVar>(inst) && inst->findDecoration<IRExternCppDecoration>()) + { + // JS: + // Turns out just doing extern "C" means something different on a variable + // So we need to wrap in extern "C" { } + m_writer->emit("extern \"C\" {\n"); + Super::emitGlobalInstImpl(inst); + m_writer->emit("\n}\n"); + } + else + { + Super::emitGlobalInstImpl(inst); + } +} + +static bool _isExported(IRInst* inst) +{ + for (auto decoration : inst->getDecorations()) + { + const auto op = decoration->getOp(); + if (op == kIROp_PublicDecoration || + op == kIROp_HLSLExportDecoration) + { + return true; + } + } + return false; +} + +void CPPSourceEmitter::emitVarDecorationsImpl(IRInst* inst) +{ + if (as<IRGlobalVar>(inst) && _isExported(inst)) + { + m_writer->emit("SLANG_PRELUDE_SHARED_LIB_EXPORT\n"); + } + + Super::emitVarDecorationsImpl(inst); +} + + +void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst) { // Specially handle export, as we don't want to emit it multiple times if (getTargetReq()->isWholeProgramRequest()) @@ -2556,7 +2599,7 @@ void CPPSourceEmitter::emitPreModuleImpl() bool isExported = false; // If public/export made it externally visible - for (auto decoration : func->getDecorations()) + for (auto decoration : inst->getDecorations()) { const auto op = decoration->getOp(); if (op == kIROp_ExternCppDecoration) @@ -2581,6 +2624,11 @@ void CPPSourceEmitter::emitPreModuleImpl() m_writer->emit("extern \"C\"\n"); } } +} + +/* virtual */void CPPSourceEmitter::emitFuncDecorationsImpl(IRFunc* func) +{ + _maybeEmitExportLike(func); // Use the default for others Super::emitFuncDecorationsImpl(func); diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index f5ba35933..6199c33f2 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -75,7 +75,9 @@ protected: virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; virtual void emitFuncDecorationsImpl(IRFunc* func) SLANG_OVERRIDE; - + virtual void emitVarDecorationsImpl(IRInst* var) SLANG_OVERRIDE; + virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; + virtual const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount); // Replaceable for classes derived from CPPSourceEmitter @@ -130,6 +132,9 @@ protected: // of all the witness table objects in `pendingWitnessTableDefinitions`. void _emitWitnessTableDefinitions(); + /// Maybe emits 'export' (such that visible outside binary/dll) and `extern "C"` naming + void _maybeEmitExportLike(IRInst* inst); + HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount); static bool _isVariable(IROp op); @@ -137,7 +142,6 @@ protected: Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap; Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap; - IRTypeSet m_typeSet; RefPtr<HLSLIntrinsicOpLookup> m_opLookup; HLSLIntrinsicSet m_intrinsicSet; diff --git a/source/slang/slang-ir-com-interface.cpp b/source/slang/slang-ir-com-interface.cpp index 009d2314d..899596209 100644 --- a/source/slang/slang-ir-com-interface.cpp +++ b/source/slang/slang-ir-com-interface.cpp @@ -7,92 +7,105 @@ namespace Slang { -struct ComInterfaceLoweringContext +static bool _canReplace(IRUse* use) { - IRModule* module; - DiagnosticSink* diagnosticSink; - - ArtifactStyle artifactStyle; - - SharedIRBuilder sharedBuilder; - - void replaceTypeUses(IRInst* inst, IRInst* newValue) + switch (use->getUser()->getOp()) { - List<IRUse*> uses; - for (auto use = inst->firstUse; use; use = use->nextUse) + case kIROp_WitnessTableIDType: + case kIROp_WitnessTableType: + case kIROp_RTTIPointerType: + case kIROp_RTTIHandleType: { - uses.add(use); + // Don't replace + return false; } - for (auto use : uses) + case kIROp_ThisType: { - switch (use->getUser()->getOp()) - { - case kIROp_WitnessTableIDType: - case kIROp_WitnessTableType: - case kIROp_ThisType: - case kIROp_RTTIPointerType: - case kIROp_RTTIHandleType: - case kIROp_ComPtrType: - case kIROp_PtrType: - continue; - default: - break; - } - use->set(newValue); + // Appears replacable. + break; } + case kIROp_ComPtrType: + case kIROp_PtrType: + { + // We can have ** and ComPtr<T>*. + // If it's a pointer type it could be because it is a global. + break; + } + default: break; } + return true; +} - IRType* processInterfaceType(IRInterfaceType* type) +void lowerComInterfaces(IRModule* module, ArtifactStyle artifactStyle, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + + // Find all of the COM interfaces + List<IRInterfaceType*> comInterfaces; + for (auto child : module->getGlobalInsts()) { - if (!type->findDecoration<IRComInterfaceDecoration>()) - return nullptr; - + auto intf = as<IRInterfaceType>(child); + if (intf && intf->findDecoration<IRComInterfaceDecoration>()) + { + comInterfaces.add(intf); + } + } + + // For all interfaces found replace uses + { + SharedIRBuilder sharedBuilder; + sharedBuilder.init(module); + IRBuilder builder(sharedBuilder); builder.setInsertInto(module->getModuleInst()); - IRType* result = (artifactStyle == ArtifactStyle::Kernel) ? - static_cast<IRType*>(builder.getPtrType(type)) : - static_cast<IRType*>(builder.getComPtrType(type)); + List<IRUse*> uses; - replaceTypeUses(type, result); - return result; - } + for (auto comIntf : comInterfaces) + { + uses.clear(); - void processThisType(IRThisType* type) - { - auto comPtrType = processInterfaceType(as<IRInterfaceType>(type->getConstraintType())); - if (!comPtrType) - return; - replaceTypeUses(type, comPtrType); - } + // Find all of the uses *before* doing any replacement + // Otherwise we end up replacing the replacement leading + // to it pointing to itself. + for (auto use = comIntf->firstUse; use; use = use->nextUse) + { + // Only store off uses where replacement can be made + if (_canReplace(use)) + { + uses.add(use); + } + } - void processModule() - { - for (auto child : module->getGlobalInsts()) - { - switch (child->getOp()) + // If there are no uses that can be replaced, then we don't need + // to create a replacement result + if (uses.getCount() <= 0) { - case kIROp_InterfaceType: - processInterfaceType(as<IRInterfaceType>(child)); - break; - case kIROp_ThisType: - processThisType(as<IRThisType>(child)); - break; - default: - break; + continue; + } + + // NOTE! The following code relies on the fact that the builder + // *doesn't* dedup in general, and in particular doesn't ptr types. + // This allows the creation a 'new' pointer type, and subsequent replacment all old uses, + // leading to a `IInterface*` becoming `IInterface**`. + // + + // TODO(JS): This is a temporary fix, in that whether kernel or not + // shouldn't control the ptr type in general + // It's necessary here though because Kernel doesn't have ComPtr<> + // so has to be a raw pointer + IRType* result = (artifactStyle == ArtifactStyle::Host) ? + static_cast<IRType*>(builder.getComPtrType(comIntf)) : + static_cast<IRType*>(builder.getPtrType(comIntf)); + + // Go through replacing all of the replacable uses + for (auto use : uses) + { + // Do the replacement + use->set(result); } } } -}; - -void lowerComInterfaces(IRModule* module, ArtifactStyle artifactStyle, DiagnosticSink* sink) -{ - ComInterfaceLoweringContext context; - context.module = module; - context.diagnosticSink = sink; - context.artifactStyle = artifactStyle; - context.sharedBuilder.init(module); - return context.processModule(); } } diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp index 6ab0d68f2..6e88c4cd7 100644 --- a/source/slang/slang-ir-explicit-global-context.cpp +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -50,6 +50,12 @@ struct IntroduceExplicitGlobalContextPass // auto globalVar = cast<IRGlobalVar>(inst); + // Actual globals don't need to be moved to the context + if (as<IRActualGlobalRate>(globalVar->getRate())) + { + continue; + } + // One important exception is that CUDA *does* support // global variables with the `__shared__` qualifer, with // semantics that exactly match HLSL/Slang `groupshared`. diff --git a/source/slang/slang-ir-explicit-global-init.cpp b/source/slang/slang-ir-explicit-global-init.cpp index 07397902e..94c065514 100644 --- a/source/slang/slang-ir-explicit-global-init.cpp +++ b/source/slang/slang-ir-explicit-global-init.cpp @@ -87,6 +87,12 @@ struct MoveGlobalVarInitializationToEntryPointsPass if(!globalVar) continue; + // If it's an `Actual Global` we don't want to move initialization + if (as<IRActualGlobalRate>(globalVar->getRate())) + { + continue; + } + auto firstBlock = globalVar->getFirstBlock(); if(!firstBlock) continue; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b77621720..6547d949e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -79,6 +79,7 @@ INST(Nop, nop, 0, 0) /* Rate */ INST(ConstExprRate, ConstExpr, 0, 0) INST(GroupSharedRate, GroupShared, 0, 0) + INST(ActualGlobalRate, ActualGlobalRate, 0, 0) INST_RANGE(Rate, ConstExprRate, GroupSharedRate) INST(RateQualifiedType, RateQualified, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9c1569ca8..5e8e11f84 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2235,6 +2235,7 @@ public: IRConstExprRate* getConstExprRate(); IRGroupSharedRate* getGroupSharedRate(); + IRActualGlobalRate* getActualGlobalRate(); IRRateQualifiedType* getRateQualifiedType( IRRate* rate, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ee84cce73..562d0ea1a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2795,6 +2795,10 @@ namespace Slang { return (IRGroupSharedRate*)getType(kIROp_GroupSharedRate); } + IRActualGlobalRate* IRBuilder::getActualGlobalRate() + { + return (IRActualGlobalRate*)getType(kIROp_ActualGlobalRate); + } IRRateQualifiedType* IRBuilder::getRateQualifiedType( IRRate* rate, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 7a9a1ebee..403376dca 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1226,6 +1226,7 @@ SIMPLE_IR_TYPE(UnsizedArrayType, ArrayTypeBase) SIMPLE_IR_PARENT_TYPE(Rate, Type) SIMPLE_IR_TYPE(ConstExprRate, Rate) SIMPLE_IR_TYPE(GroupSharedRate, Rate) +SIMPLE_IR_TYPE(ActualGlobalRate, Rate) struct IRRateQualifiedType : IRType { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5ffb1bf33..d175b69dd 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2034,6 +2034,12 @@ void maybeSetRate( builder->getGroupSharedRate(), inst->getFullType())); } + else if (decl->hasModifier<ActualGlobalModifier>()) + { + inst->setFullType(builder->getRateQualifiedType( + builder->getActualGlobalRate(), + inst->getFullType())); + } } static String getNameForNameHint( diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 1c32499ef..b2179c1af 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6339,6 +6339,8 @@ namespace Slang _makeParseModifier("instance", InstanceModifier::kReflectClassInfo), _makeParseModifier("__builtin", BuiltinModifier::kReflectClassInfo), + _makeParseModifier("__global", ActualGlobalModifier::kReflectClassInfo), + _makeParseModifier("inline", InlineModifier::kReflectClassInfo), _makeParseModifier("public", PublicModifier::kReflectClassInfo), _makeParseModifier("require", RequireModifier::kReflectClassInfo), diff --git a/tools/slang-unit-test/unit-test-com-host-callable.cpp b/tools/slang-unit-test/unit-test-com-host-callable.cpp new file mode 100644 index 000000000..57209fbde --- /dev/null +++ b/tools/slang-unit-test/unit-test-com-host-callable.cpp @@ -0,0 +1,335 @@ +// unit-test-com-host-callable.cpp + +#include "../../source/core/slang-byte-encode-util.h" + +#include <stdio.h> +#include <stdlib.h> + +#include "tools/unit-test/slang-unit-test.h" + +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../../slang-com-ptr.h" + +#include "../../source/core/slang-list.h" + +namespace { // anonymous + +// Slang namespace is used for elements support code (like core) which we use here +// for ComPtr<> and TestToolUtil +using namespace Slang; + +// For the moment we have to explicitly write the Slang COM interface in C++ code. It *MUST* match +// the interface in the slang source +// As it stands all interfaces need to derive from ISlangUnknown (or IUnknown). +class IDoThings : public ISlangUnknown +{ +public: + virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) = 0; + virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) = 0; +}; + +class ICountGood : public ISlangUnknown +{ +public: + virtual SLANG_NO_THROW int SLANG_MCALL nextCount() = 0; +}; + +static int _calcHash(const char* in) +{ + int hash = 0; + for (; *in; ++in) + { + // A very poor hash function + hash = hash * 13 + *in; + } + return hash; +} + +class DoThings : public IDoThings +{ +public: + // We don't need queryInterface for this impl, or ref counting + virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE { return SLANG_E_NOT_IMPLEMENTED; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; } + + // IDoThings + virtual SLANG_NO_THROW int SLANG_MCALL doThing(int a, int b) SLANG_OVERRIDE { return a + b + 1; } + virtual SLANG_NO_THROW int SLANG_MCALL calcHash(const char* in) SLANG_OVERRIDE { return (int)_calcHash(in); } +}; + +class CountGood : public ICountGood +{ +public: + // We don't need queryInterface for this impl, or ref counting + virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE { return SLANG_E_NOT_IMPLEMENTED; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return 1; } + + // ICountGood + virtual SLANG_NO_THROW int SLANG_MCALL nextCount() SLANG_OVERRIDE { return m_count++; } + + int m_count = 0; +}; + +struct ComTestContext +{ + ComTestContext(UnitTestContext* context): + m_unitTestContext(context) + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + + m_defaultCppCompiler = slangSession->getDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP); + + m_hostHostCallableCompiler = slangSession->getDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE); + m_shaderHostCallableCompiler = slangSession->getDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE); + + } + + SlangResult runTests() + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + + // TODO(JS): + // Care is needed around this in normal testing. `slang-llvm` is whatever was asked for for when premake was built + // when the target is specified. Otherwise it is the `default` which is typically 64 bit during development. + // + // On CI we should be okay, because it should download the correct `slang-llvm` for the build (as it packages up with it). + // But for normal development, that can easily not be the case (for example changing to 32 bit build in VS is a problem). + // + // Make sure to run + // + // ``` + // premake --arch=x86 --deps=true + // ``` + // + // for the actual target/arch(!) + + const bool hasLlvm = SLANG_SUCCEEDED(slangSession->checkPassThroughSupport(SLANG_PASS_THROUGH_LLVM)); + + SlangPassThrough cppCompiler = SLANG_PASS_THROUGH_NONE; + + { + const SlangPassThrough cppCompilers[] = + { + SLANG_PASS_THROUGH_VISUAL_STUDIO, + SLANG_PASS_THROUGH_GCC, + SLANG_PASS_THROUGH_CLANG, + }; + // Do we have a C++ compiler + for (const auto compiler : cppCompilers) + { + if (SLANG_SUCCEEDED(slangSession->checkPassThroughSupport(compiler))) + { + cppCompiler = compiler; + break; + } + } + } + + // If we have an *actual* C++ compile rtest on that first + if (cppCompiler != SLANG_PASS_THROUGH_NONE) + { + slangSession->setDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP, cppCompiler); + + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE, cppCompiler); + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE, cppCompiler); + + SLANG_RETURN_ON_FAIL(_runTest()); + } + + // Reset the compiler that's used for host-callable + _reset(); + + // If we have Llvm it is the default host callable compiler + if (hasLlvm) + { + // Should run via slang-llvm + SLANG_RETURN_ON_FAIL(_runTest()); + } + + return SLANG_OK; + } + + void _reset() + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + slangSession->setDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP, m_defaultCppCompiler); + + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_SHADER_HOST_CALLABLE, m_shaderHostCallableCompiler); + slangSession->setDownstreamCompilerForTransition(SLANG_CPP_SOURCE, SLANG_HOST_HOST_CALLABLE, m_hostHostCallableCompiler); + } + + ~ComTestContext() + { + _reset(); + } + + SlangResult _runTest(); + + UnitTestContext* m_unitTestContext; + + SlangPassThrough m_defaultCppCompiler; + SlangPassThrough m_hostHostCallableCompiler; + SlangPassThrough m_shaderHostCallableCompiler; +}; + +SlangResult ComTestContext::_runTest() +{ + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + + // Create a compile request + Slang::ComPtr<slang::ICompileRequest> request; + SLANG_RETURN_ON_FAIL(slangSession->createCompileRequest(request.writeRef())); + + // We want to compile to 'HOST_CALLABLE' here such that we can execute the Slang code. + // + // Note that it is possible to use HOST_HOST_CALLABLE, but this currently only works with 'regular' C++ compilers + // not with `slang-llvm`. + const int targetIndex = request->addCodeGenTarget(SLANG_SHADER_HOST_CALLABLE); + + // Set the target flag to indicate that we want to compile all into a library. + request->setTargetFlags(targetIndex, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM); + + request->setOptimizationLevel(SLANG_OPTIMIZATION_LEVEL_NONE); + request->setDebugInfoLevel(SLANG_DEBUG_INFO_LEVEL_STANDARD); + + // Add the translation unit + const int translationUnitIndex = request->addTranslationUnit(SLANG_SOURCE_LANGUAGE_SLANG, nullptr); + + // Set the source file for the translation unit + request->addTranslationUnitSourceFile(translationUnitIndex, "tools/slang-unit-test/unit-test-com-host-callable.slang"); + + const SlangResult compileRes = request->compile(); + + // Even if there were no errors that forced compilation to fail, the + // compiler may have produced "diagnostic" output such as warnings. + // We will go ahead and print that output here. + // + if (auto diagnostics = request->getDiagnosticOutput()) + { + printf("%s", diagnostics); + } + + // Get the 'shared library' (note that this doesn't necessarily have to be implemented as a shared library + // it's just an interface to executable code). + ComPtr<ISlangSharedLibrary> sharedLibrary; + SLANG_RETURN_ON_FAIL(request->getTargetHostCallable(0, sharedLibrary.writeRef())); + + { + typedef const char* (*Func)(const char*); + Func func = (Func)sharedLibrary->findFuncByName("getString"); + + if (!func) + { + return SLANG_FAIL; + } + + String text = "Hello World!"; + String returnedText = func(text.getBuffer()); + + SLANG_CHECK(text == returnedText); + } + { + typedef int (*Func)(const char* text, IDoThings* doThings); + + Func func = (Func)sharedLibrary->findFuncByName("calcHash"); + + if (!func) + { + return SLANG_FAIL; + } + + DoThings doThings; + + String text("Hello"); + + const int hash = func(text.getBuffer(), &doThings); + + SLANG_CHECK(hash == _calcHash(text.getBuffer())); + } + + // Check accessing a global + { + typedef void (*SetFunc)(int v); + typedef int (*GetFunc)(); + + const auto setGlobal = (SetFunc)sharedLibrary->findFuncByName("setGlobal"); + const auto getGlobal = (GetFunc)sharedLibrary->findFuncByName("getGlobal"); + + if (setGlobal == nullptr || getGlobal == nullptr) + { + return SLANG_FAIL; + } + + // In the slang source it is set a default value + SLANG_CHECK(getGlobal() == 10); + + for (Index i = 0; i < 10; ++i) + { + setGlobal(int(i)); + SLANG_CHECK(getGlobal() == i); + } + } + + // Check using a global interface + { + + typedef void (*SetCounterFunc)(ICountGood* counter); + typedef int (*NextCountFunc)(); + + const auto setCounter = (SetCounterFunc)sharedLibrary->findFuncByName("setCounter"); + const auto nextCount = (NextCountFunc)sharedLibrary->findFuncByName("nextCount"); + + if (setCounter == nullptr || nextCount == nullptr) + { + return SLANG_FAIL; + } + + CountGood counter; + + ICountGood* counterIntf = &counter; + + setCounter(counterIntf); + + auto counterPtr = (ICountGood**)sharedLibrary->findSymbolAddressByName("globalCounter"); + SLANG_CHECK(counterPtr); + if (!counterPtr) + { + return SLANG_FAIL; + } + + for (Index i = 0; i < 10; ++i) + { + SLANG_CHECK(*counterPtr == &counter); + + const auto v = nextCount(); + SLANG_CHECK(v == i); + } + } + + return SLANG_OK; +} + +} // anonymous + +SLANG_UNIT_TEST(comHostCallable) +{ +#if SLANG_PTR_IS_32 && !SLANG_MICROSOFT_FAMILY + // TODO(JS): + // We can't currently run this test reliably on targets other than windows + // Visual Studio DownstreamCompiler has support for 32 bit builds + // Other targets generally build for the native environment which is almost always 64 bit, + // and it requires other features to build/test 32 bit binaries on such systems. + // + // So we disable for any 32 bit non MS target for now + return; +#endif + + ComTestContext context(unitTestContext); + + const auto result = context.runTests(); + + SLANG_CHECK(SLANG_SUCCEEDED(result)); +} diff --git a/tools/slang-unit-test/unit-test-com-host-callable.slang b/tools/slang-unit-test/unit-test-com-host-callable.slang new file mode 100644 index 000000000..b591904b3 --- /dev/null +++ b/tools/slang-unit-test/unit-test-com-host-callable.slang @@ -0,0 +1,50 @@ +// shader.slang + +// Example of using 'NativeString' + +public __extern_cpp NativeString getString(NativeString in) +{ + return in; +} + +public __extern_cpp __global int intGlobal = 10; + +public __extern_cpp void setGlobal(int v) +{ + intGlobal = v; +} + +public __extern_cpp int getGlobal() +{ + return intGlobal; +} + +[COM] +interface IDoThings +{ + int doThing(int a, int b); + int calcHash(NativeString in); +} + +[COM] +interface ICountGood +{ + int nextCount(); +} + +public __extern_cpp __global ICountGood globalCounter; + +public __extern_cpp void setCounter(ICountGood counter) +{ + globalCounter = counter; +} + +public __extern_cpp int nextCount() +{ + return globalCounter.nextCount(); +} + +public __extern_cpp int calcHash(NativeString text, IDoThings doThings) +{ + return doThings.calcHash(text); +} diff --git a/tools/unit-test/slang-unit-test.h b/tools/unit-test/slang-unit-test.h index 035f67e33..3c8bcaea6 100644 --- a/tools/unit-test/slang-unit-test.h +++ b/tools/unit-test/slang-unit-test.h @@ -67,7 +67,7 @@ typedef IUnitTestModule* (*UnitTestGetModuleFunc)(); void _##name##_impl(UnitTestContext* unitTestContext); \ void name(UnitTestContext* unitTestContext)\ {\ - try { _##name##_impl(unitTestContext); } catch (AbortTestException){} \ + try { _##name##_impl(unitTestContext); } catch (AbortTestException&){} \ }\ UnitTestRegisterHelper _##name##RegisterHelper(#name, name); \ void _##name##_impl(UnitTestContext* unitTestContext) |
