From af27de01532904508e6a630c213249e93fdd1c66 Mon Sep 17 00:00:00 2001 From: jarcherNV Date: Fri, 15 Aug 2025 13:16:09 -0700 Subject: Add static functions to create blobs from data (#8179) Add helper functions to create ISlangBlob and load module data from source. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- tools/slang-unit-test/unit-test-create-blob.cpp | 169 ++++++++++ tools/slang-unit-test/unit-test-ir-blob.cpp | 370 +++++++++++++++++++++ .../unit-test-load-module-from-source.cpp | 268 +++++++++++++++ 3 files changed, 807 insertions(+) create mode 100644 tools/slang-unit-test/unit-test-create-blob.cpp create mode 100644 tools/slang-unit-test/unit-test-ir-blob.cpp create mode 100644 tools/slang-unit-test/unit-test-load-module-from-source.cpp (limited to 'tools') diff --git a/tools/slang-unit-test/unit-test-create-blob.cpp b/tools/slang-unit-test/unit-test-create-blob.cpp new file mode 100644 index 000000000..edb136415 --- /dev/null +++ b/tools/slang-unit-test/unit-test-create-blob.cpp @@ -0,0 +1,169 @@ +// unit-test-create-blob.cpp + +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include +#include + +using namespace Slang; + +// Test the slang_createBlob function +SLANG_UNIT_TEST(createBlob) +{ + // Test 1: Basic functionality with valid string data + { + const char* testData = "Hello, World!"; + size_t dataSize = strlen(testData); + + ComPtr blob; + blob = slang_createBlob((const void*)testData, dataSize); + + SLANG_CHECK(blob != nullptr); + + if (blob) + { + SLANG_CHECK(blob->getBufferSize() == dataSize); + SLANG_CHECK(memcmp(blob->getBufferPointer(), testData, dataSize) == 0); + } + } + + // Test 2: Test with binary data (non-string) + { + const uint8_t binaryData[] = {0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD, 0xFC}; + size_t dataSize = sizeof(binaryData); + + ComPtr blob; + blob = slang_createBlob((const void*)binaryData, dataSize); + + SLANG_CHECK(blob != nullptr); + + if (blob) + { + SLANG_CHECK(blob->getBufferSize() == dataSize); + SLANG_CHECK(memcmp(blob->getBufferPointer(), binaryData, dataSize) == 0); + } + } + + // Test 3: Test with large data + { + const size_t largeSize = 1024 * 1024; // 1MB + char* largeData = new char[largeSize]; + + // Fill with pattern + for (size_t i = 0; i < largeSize; i++) + { + largeData[i] = (char)(i % 256); + } + + ComPtr blob; + blob = slang_createBlob((const void*)largeData, largeSize); + + SLANG_CHECK(blob != nullptr); + + if (blob) + { + SLANG_CHECK(blob->getBufferSize() == largeSize); + SLANG_CHECK(memcmp(blob->getBufferPointer(), largeData, largeSize) == 0); + } + + delete[] largeData; + } + + // Test 4: Test with null pointer and non-zero size (should fail) + { + char* testData = nullptr; + ComPtr blob; + blob = slang_createBlob((const void*)testData, 10); + + SLANG_CHECK(blob == nullptr); + } + + // Test 5: Test with null pointer and zero size (should fail) + { + char* testData = nullptr; + ComPtr blob; + blob = slang_createBlob((const void*)testData, 0); + + SLANG_CHECK(blob == nullptr); + } + + // Test 6: Test with valid pointer and zero size (should fail) + { + const char* testData = "test"; + + ComPtr blob; + blob = slang_createBlob((const void*)testData, 0); + + SLANG_CHECK(blob == nullptr); + } + + // Test 7: Test with void* version of the function + { + const char* testData = "Test void* version"; + size_t dataSize = strlen(testData); + + ComPtr blob; + blob = slang_createBlob((const void*)testData, dataSize); + + SLANG_CHECK(blob != nullptr); + + if (blob) + { + SLANG_CHECK(blob->getBufferSize() == dataSize); + SLANG_CHECK(memcmp(blob->getBufferPointer(), testData, dataSize) == 0); + } + } + + // Test 8: Test multiple blobs with same data + { + const char* testData = "Shared data"; + size_t dataSize = strlen(testData); + + ComPtr blob1; + ComPtr blob2; + ComPtr blob3; + blob1 = slang_createBlob((const void*)testData, dataSize); + blob2 = slang_createBlob((const void*)testData, dataSize); + blob3 = slang_createBlob((const void*)testData, dataSize); + + SLANG_CHECK(blob1 != nullptr); + SLANG_CHECK(blob2 != nullptr); + SLANG_CHECK(blob3 != nullptr); + + // All should have same content + SLANG_CHECK(blob1->getBufferSize() == dataSize); + SLANG_CHECK(blob2->getBufferSize() == dataSize); + SLANG_CHECK(blob3->getBufferSize() == dataSize); + + SLANG_CHECK(memcmp(blob1->getBufferPointer(), testData, dataSize) == 0); + SLANG_CHECK(memcmp(blob2->getBufferPointer(), testData, dataSize) == 0); + SLANG_CHECK(memcmp(blob3->getBufferPointer(), testData, dataSize) == 0); + } + + // Test 9: Test memory management (blob should be independent) + { + const char* testData = "Memory test"; + size_t dataSize = strlen(testData); + + ComPtr blob; + blob = slang_createBlob((const void*)testData, dataSize); + + // Modify original data - blob should remain unchanged + char* mutableData = new char[dataSize + 1]; + memcpy(mutableData, testData, dataSize + 1); + mutableData[0] = 'X'; // Change first character + + SLANG_CHECK(blob->getBufferSize() == dataSize); + SLANG_CHECK( + memcmp(blob->getBufferPointer(), testData, dataSize) == + 0); // Should still match original + SLANG_CHECK( + memcmp(blob->getBufferPointer(), mutableData, dataSize) != + 0); // Should not match modified + + delete[] mutableData; + } +} diff --git a/tools/slang-unit-test/unit-test-ir-blob.cpp b/tools/slang-unit-test/unit-test-ir-blob.cpp new file mode 100644 index 000000000..6f72afae5 --- /dev/null +++ b/tools/slang-unit-test/unit-test-ir-blob.cpp @@ -0,0 +1,370 @@ +// unit-test-ir-blob.cpp + +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include +#include + +using namespace Slang; + +// Test the slang_loadModuleFromIRBlob and slang_loadModuleInfoFromIRBlob functions +SLANG_UNIT_TEST(irBlob) +{ + // Test source code for creating IR data + const char* testModuleSource = R"( + module test_ir_module; + + public struct TestStruct { + float x, y, z; + } + + public void testFunction(TestStruct input) { + // Simple function + } + + public static const float PI = 3.14159; + )"; + + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + sessionDesc.targets = &targetDesc; + + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + // Create IR data by serializing a module + ComPtr irBlob; + { + ComPtr module; + ComPtr diagnostics; + + module = session->loadModuleFromSourceString( + "test_ir_module", + "test_ir_module.slang", + testModuleSource, + diagnostics.writeRef()); + + SLANG_CHECK(module != nullptr); + if (diagnostics && diagnostics->getBufferSize() > 0) + { + // Log diagnostics if any + printf( + "Module compilation diagnostics: %.*s\n", + (int)diagnostics->getBufferSize(), + (const char*)diagnostics->getBufferPointer()); + } + + // Serialize the module to create IR data + SLANG_CHECK(module->serialize(irBlob.writeRef()) == SLANG_OK); + SLANG_CHECK(irBlob != nullptr); + SLANG_CHECK(irBlob->getBufferSize() > 0); + } + + // Test 1: Test slang_loadModuleFromIRBlob with valid IR data + { + ComPtr loadedModule; + ComPtr diagnostics; + + loadedModule = slang_loadModuleFromIRBlob( + session, + "test_ir_module_loaded", + "test_ir_module_loaded.slang", + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(loadedModule != nullptr); + if (diagnostics && diagnostics->getBufferSize() > 0) + { + // Log diagnostics if any + printf( + "IR blob loading diagnostics: %.*s\n", + (int)diagnostics->getBufferSize(), + (const char*)diagnostics->getBufferPointer()); + } + + // Verify the loaded module is valid + SLANG_CHECK(loadedModule != nullptr); + } + + // Test 2: Test slang_loadModuleInfoFromIRBlob with valid IR data + { + SlangInt moduleVersion; + const char* moduleCompilerVersion; + const char* moduleName; + + SlangResult result = slang_loadModuleInfoFromIRBlob( + session, + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(result == SLANG_OK); + SLANG_CHECK(moduleName != nullptr); + SLANG_CHECK(strcmp(moduleName, "test_ir_module") == 0); + SLANG_CHECK(moduleCompilerVersion != nullptr); + SLANG_CHECK(moduleVersion >= 0); + } + + // Test 3: Test slang_loadModuleFromIRBlob with invalid parameters + { + ComPtr module; + ComPtr diagnostics; + + // Test with null session + module = slang_loadModuleFromIRBlob( + nullptr, + "testModule", + "test.slang", + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null moduleName + module = slang_loadModuleFromIRBlob( + session, + nullptr, + "test.slang", + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null path + module = slang_loadModuleFromIRBlob( + session, + "testModule", + nullptr, + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null source + module = slang_loadModuleFromIRBlob( + session, + "testModule", + "test.slang", + nullptr, + irBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with zero size + module = slang_loadModuleFromIRBlob( + session, + "testModule", + "test.slang", + irBlob->getBufferPointer(), + 0, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + } + + // Test 4: Test slang_loadModuleInfoFromIRBlob with invalid parameters + { + SlangInt moduleVersion; + const char* moduleCompilerVersion; + const char* moduleName; + + // Test with null session + SlangResult result = slang_loadModuleInfoFromIRBlob( + nullptr, + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(result == SLANG_E_INVALID_ARG); + + // Test with null source + result = slang_loadModuleInfoFromIRBlob( + session, + nullptr, + irBlob->getBufferSize(), + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(result == SLANG_E_INVALID_ARG); + + // Test with zero size + result = slang_loadModuleInfoFromIRBlob( + session, + irBlob->getBufferPointer(), + 0, + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(result == SLANG_E_INVALID_ARG); + } + + // Test 5: Test with corrupted/invalid IR data + { + ComPtr module; + ComPtr diagnostics; + + // Create some invalid data + const char* invalidData = "This is not valid IR data"; + size_t invalidDataSize = strlen(invalidData); + + module = slang_loadModuleFromIRBlob( + session, + "testModule", + "test.slang", + invalidData, + invalidDataSize, + diagnostics.writeRef()); + + // This might return nullptr or a module with diagnostics + if (module == nullptr) + { + // If it failed, that's expected for invalid data + SLANG_CHECK(true); + } + else + { + // If it succeeded, there should be diagnostics + if (diagnostics && diagnostics->getBufferSize() > 0) + { + SLANG_CHECK(true); + } + } + } + + // Test 6: Test slang_loadModuleInfoFromIRBlob with corrupted/invalid IR data + { + SlangInt moduleVersion; + const char* moduleCompilerVersion; + const char* moduleName; + + // Create some invalid data + const char* invalidData = "This is not valid IR data"; + size_t invalidDataSize = strlen(invalidData); + + SlangResult result = slang_loadModuleInfoFromIRBlob( + session, + invalidData, + invalidDataSize, + moduleVersion, + moduleCompilerVersion, + moduleName); + + // This should fail with invalid data + SLANG_CHECK(result != SLANG_OK); + } + + // Test 7: Test round-trip serialization and loading + { + // Load the module from IR + ComPtr loadedModule; + ComPtr diagnostics; + + loadedModule = slang_loadModuleFromIRBlob( + session, + "test_round_trip", + "test_round_trip.slang", + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(loadedModule != nullptr); + + if (loadedModule) + { + // Serialize the loaded module again + ComPtr roundTripBlob; + SLANG_CHECK(loadedModule->serialize(roundTripBlob.writeRef()) == SLANG_OK); + SLANG_CHECK(roundTripBlob != nullptr); + SLANG_CHECK(roundTripBlob->getBufferSize() > 0); + + // Load it again + ComPtr roundTripModule; + roundTripModule = slang_loadModuleFromIRBlob( + session, + "test_round_trip_2", + "test_round_trip_2.slang", + roundTripBlob->getBufferPointer(), + roundTripBlob->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(roundTripModule != nullptr); + } + } + + // Test 8: Test multiple modules with different IR data + { + // Create a second module with different content + const char* testModuleSource2 = R"( + module test_ir_module_2; + + public struct AnotherStruct { + int a, b, c; + } + + public void anotherFunction(AnotherStruct input) { + // Another function + } + )"; + + ComPtr module2; + ComPtr diagnostics2; + ComPtr irBlob2; + + module2 = session->loadModuleFromSourceString( + "test_ir_module_2", + "test_ir_module_2.slang", + testModuleSource2, + diagnostics2.writeRef()); + + SLANG_CHECK(module2 != nullptr); + SLANG_CHECK(module2->serialize(irBlob2.writeRef()) == SLANG_OK); + + // Load both modules + ComPtr loadedModule1; + ComPtr loadedModule2; + ComPtr diagnostics; + + loadedModule1 = slang_loadModuleFromIRBlob( + session, + "test_ir_module_1_loaded", + "test_ir_module_1_loaded.slang", + irBlob->getBufferPointer(), + irBlob->getBufferSize(), + diagnostics.writeRef()); + + loadedModule2 = slang_loadModuleFromIRBlob( + session, + "test_ir_module_2_loaded", + "test_ir_module_2_loaded.slang", + irBlob2->getBufferPointer(), + irBlob2->getBufferSize(), + diagnostics.writeRef()); + + SLANG_CHECK(loadedModule1 != nullptr); + SLANG_CHECK(loadedModule2 != nullptr); + + // Verify both modules loaded successfully + SLANG_CHECK(loadedModule1 != nullptr); + SLANG_CHECK(loadedModule2 != nullptr); + } +} diff --git a/tools/slang-unit-test/unit-test-load-module-from-source.cpp b/tools/slang-unit-test/unit-test-load-module-from-source.cpp new file mode 100644 index 000000000..41c32211b --- /dev/null +++ b/tools/slang-unit-test/unit-test-load-module-from-source.cpp @@ -0,0 +1,268 @@ +// unit-test-load-module-from-source.cpp + +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include + +using namespace Slang; + +// Test the loadModuleFromSource method and slang_loadModuleFromSource function +SLANG_UNIT_TEST(loadModuleFromSource) +{ + // Test source code with various content + const char* testSource = R"( + [shader("compute")] + [numthreads(1,1,1)] + void computeMain(uint3 workGroup : SV_GroupID) + { + // Simple compute shader + } + )"; + + size_t sourceSize = strlen(testSource); + + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + sessionDesc.targets = &targetDesc; + + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + // Test 1: Test the method version (existing functionality) + { + ComPtr module; + ComPtr diagnostics; + + // Use loadModuleFromSourceString which takes a string directly + module = session->loadModuleFromSourceString( + "testModule", + "test.slang", + testSource, + diagnostics.writeRef()); + + SLANG_CHECK(module != nullptr); + if (diagnostics) + { + // If there are diagnostics, they should be warnings or errors + SLANG_CHECK(diagnostics->getBufferSize() > 0); + } + } + + // Test 2: Test the new slang_loadModuleFromSource function + { + ComPtr module; + ComPtr diagnostics; + + module = slang_loadModuleFromSource( + session, + "testModule2", + "test2.slang", + testSource, + sourceSize, + diagnostics.writeRef()); + + SLANG_CHECK(module != nullptr); + if (diagnostics) + { + // If there are diagnostics, they should be warnings or errors + SLANG_CHECK(diagnostics->getBufferSize() > 0); + } + } + + // Test 3: Test with invalid parameters + { + ComPtr module; + ComPtr diagnostics; + + // Test with null session + module = slang_loadModuleFromSource( + nullptr, + "testModule", + "test.slang", + testSource, + sourceSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null moduleName + module = slang_loadModuleFromSource( + session, + nullptr, + "test.slang", + testSource, + sourceSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null path + module = slang_loadModuleFromSource( + session, + "testModule", + nullptr, + testSource, + sourceSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + } + + // Test 4: Test with null source and non-zero size (should fail) + { + ComPtr module; + ComPtr diagnostics; + + module = slang_loadModuleFromSource( + session, + "testModule", + "test.slang", + nullptr, + 10, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + } + + // Test 5: Test with complex source code + { + const char* complexSource = R"( + [shader("compute")] + [numthreads(8,8,1)] + void computeMain(uint3 workGroup : SV_GroupID, uint3 localID : SV_GroupThreadID) + { + uint2 pixelPos = workGroup.xy * uint2(8,8) + localID.xy; + + // Simple computation + float result = sin(pixelPos.x * 0.1f) * cos(pixelPos.y * 0.1f); + + // Store result (in a real shader, this would go to a buffer) + // outputBuffer[pixelPos] = result; + } + )"; + + size_t complexSourceSize = strlen(complexSource); + + ComPtr module; + ComPtr diagnostics; + + module = slang_loadModuleFromSource( + session, + "complexModule", + "complex.slang", + complexSource, + complexSourceSize, + diagnostics.writeRef()); + + SLANG_CHECK(module != nullptr); + } + + // Test 6: Test IR blob functions with invalid parameters + { + ComPtr module; + ComPtr diagnostics; + const char* testData = "test data"; + size_t testDataSize = strlen(testData); + + // Test with null session + module = slang_loadModuleFromIRBlob( + nullptr, + "testModule", + "test.slang", + testData, + testDataSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null moduleName + module = slang_loadModuleFromIRBlob( + session, + nullptr, + "test.slang", + testData, + testDataSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null path + module = slang_loadModuleFromIRBlob( + session, + "testModule", + nullptr, + testData, + testDataSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with null source + module = slang_loadModuleFromIRBlob( + session, + "testModule", + "test.slang", + nullptr, + testDataSize, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test with zero size + module = slang_loadModuleFromIRBlob( + session, + "testModule", + "test.slang", + testData, + 0, + diagnostics.writeRef()); + + SLANG_CHECK(module == nullptr); + + // Test loadModuleInfoFromIRBlob with null session + SlangInt moduleVersion; + const char* moduleCompilerVersion; + const char* moduleName; + + SlangResult infoResult = slang_loadModuleInfoFromIRBlob( + nullptr, + testData, + testDataSize, + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(infoResult == SLANG_E_INVALID_ARG); + + // Test loadModuleInfoFromIRBlob with null source + infoResult = slang_loadModuleInfoFromIRBlob( + session, + nullptr, + testDataSize, + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(infoResult == SLANG_E_INVALID_ARG); + + // Test loadModuleInfoFromIRBlob with zero size + infoResult = slang_loadModuleInfoFromIRBlob( + session, + testData, + 0, + moduleVersion, + moduleCompilerVersion, + moduleName); + + SLANG_CHECK(infoResult == SLANG_E_INVALID_ARG); + } +} -- cgit v1.2.3