1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
#include "gfx-test-util.h"
#include "tools/unit-test/slang-unit-test.h"
#include <slang-com-ptr.h>
using Slang::ComPtr;
namespace gfx_test
{
/// Forwards diagnostics text produced by a Slang API call to the test reporter
/// as an informational message.
///
/// @param diagnosticsBlob Diagnostics output of a Slang call. It may be null or
///                        empty; in either case nothing is reported.
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob)
{
    if (diagnosticsBlob == nullptr || diagnosticsBlob->getBufferSize() == 0)
        return;

    // NOTE(review): assumes the blob holds a null-terminated string, as Slang
    // diagnostic blobs normally do. Confirm this if other producers are used.
    const char* text = static_cast<const char*>(diagnosticsBlob->getBufferPointer());
    getTestReporter()->message(TestMessageType::Info, text);
}

/// Loads a Slang module, composes it with one named entry point, and creates a
/// compute shader program from the composition.
///
/// @param device            Device used to obtain the Slang session and to create the program.
/// @param outShaderProgram  Receives the created program. It is written only on success.
/// @param shaderModuleName  Name of the module to load through the device's Slang session.
/// @param entryPointName    Name of the entry point to look up in that module.
/// @param slangReflection   Receives the composed program's layout. It is written only on
///                          success. The layout belongs to the composed component type.
///                          NOTE(review): presumably the created shader program keeps that
///                          component type alive; confirm before relying on the pointer.
/// @return SLANG_OK on success. Otherwise returns the failing call's result, or SLANG_FAIL
///         if the module could not be loaded. Any diagnostics are forwarded to the test
///         reporter.
Slang::Result loadComputeProgram(
    gfx::IDevice* device,
    Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
    const char* shaderModuleName,
    const char* entryPointName,
    slang::ProgramLayout*& slangReflection)
{
    Slang::ComPtr<slang::ISession> slangSession;
    SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));

    Slang::ComPtr<slang::IBlob> diagnosticsBlob;
    slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    if (!module)
        return SLANG_FAIL;

    ComPtr<slang::IEntryPoint> computeEntryPoint;
    SLANG_RETURN_ON_FAIL(
        module->findEntryPointByName(entryPointName, computeEntryPoint.writeRef()));

    Slang::List<slang::IComponentType*> componentTypes;
    componentTypes.add(module);
    componentTypes.add(computeEntryPoint);

    Slang::ComPtr<slang::IComponentType> composedProgram;
    SlangResult result = slangSession->createCompositeComponentType(
        componentTypes.getBuffer(),
        componentTypes.getCount(),
        composedProgram.writeRef(),
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_RETURN_ON_FAIL(result);

    gfx::IShaderProgram::Desc programDesc = {};
    programDesc.pipelineType = gfx::PipelineType::Compute;
    programDesc.slangProgram = composedProgram.get();

    // Use the Result-returning overload. The convenience overload yields a null
    // program on failure, which this function previously reported as SLANG_OK.
    Slang::ComPtr<gfx::IShaderProgram> shaderProgram;
    SLANG_RETURN_ON_FAIL(device->createProgram(programDesc, shaderProgram.writeRef()));

    // Publish the outputs only once everything has succeeded. On an early return,
    // composedProgram is released, and a layout pointer taken from it would dangle.
    slangReflection = composedProgram->getLayout();
    outShaderProgram = shaderProgram;
    return SLANG_OK;
}

/// Reads back the first @p expectedBufferSize bytes of @p buffer and records a
/// Pass or Fail result with the test reporter, depending on whether those bytes
/// equal @p expectedResult byte-for-byte.
///
/// A failing readback call is handled by GFX_CHECK_CALL_ABORT (defined elsewhere;
/// presumably it aborts the test). Fail is recorded if no blob comes back or if the
/// blob is shorter than expected.
///
/// @param device             Device used to perform the readback.
/// @param buffer             Buffer whose contents are checked.
/// @param expectedResult     Expected bytes; at least @p expectedBufferSize are read.
/// @param expectedBufferSize Number of bytes to read back and compare.
void compareComputeResult(gfx::IDevice* device, gfx::IBufferResource* buffer, uint8_t* expectedResult, size_t expectedBufferSize)
{
    // Read back the results.
    ComPtr<ISlangBlob> resultBlob;
    GFX_CHECK_CALL_ABORT(device->readBufferResource(
        buffer, 0, expectedBufferSize, resultBlob.writeRef()));
    if (!resultBlob || resultBlob->getBufferSize() < expectedBufferSize)
    {
        getTestReporter()->addResult(TestResult::Fail);
        return;
    }

    // Compare results. The index is size_t to match the bound. An int index is a
    // signed/unsigned comparison, and it overflows (UB) past INT_MAX bytes.
    const auto* actual = static_cast<const uint8_t*>(resultBlob->getBufferPointer());
    for (size_t i = 0; i < expectedBufferSize; ++i)
    {
        if (expectedResult[i] != actual[i])
        {
            getTestReporter()->addResult(TestResult::Fail);
            return;
        }
    }
    getTestReporter()->addResult(TestResult::Pass);
}
}
|