diff options
| -rw-r--r-- | build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj | 1 | ||||
| -rw-r--r-- | build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters | 3 | ||||
| -rw-r--r-- | docs/user-guide/08-compiling.md | 2 | ||||
| -rw-r--r-- | slang.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-compiler-options.cpp | 108 | ||||
| -rw-r--r-- | source/slang/slang-compiler-options.h | 23 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops-debug-info-ext.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 48 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 7 | ||||
| -rw-r--r-- | tools/gfx-unit-test/compute-trivial.cpp | 99 |
15 files changed, 366 insertions, 24 deletions
diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj index 3b3df0754..955c702d9 100644 --- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj +++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj @@ -291,6 +291,7 @@ <ClCompile Include="..\..\..\tools\gfx-unit-test\buffer-barrier-test.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\clear-texture-test.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\compute-smoke.cpp" />
+ <ClCompile Include="..\..\..\tools\gfx-unit-test\compute-trivial.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\copy-texture-tests.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\create-buffer-from-handle.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\existing-device-handle-test.cpp" />
diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters index 9b98ce1c8..6d48a99ae 100644 --- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters @@ -29,6 +29,9 @@ <ClCompile Include="..\..\..\tools\gfx-unit-test\compute-smoke.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\gfx-unit-test\compute-trivial.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\gfx-unit-test\copy-texture-tests.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/docs/user-guide/08-compiling.md b/docs/user-guide/08-compiling.md index a00b7670e..358246b5a 100644 --- a/docs/user-guide/08-compiling.md +++ b/docs/user-guide/08-compiling.md @@ -657,7 +657,7 @@ meanings of their `CompilerOptionValue` encodings. | LineDirectiveMode | Specifies the line directive mode to use the generated textual code such as HLSL or CUDA. `intValue0` encodes an value defined in the `SlangLineDirectiveMode` enum. | | Optimization | Specifies the optimization level. `intValue0` encodes the value for the setting defined in the `SlangOptimizationLevel` enum. | | Obfuscate | Specifies whether or not to turn on obfuscation. When obfuscation is on, Slang will strip variable and function names from the target code and replace them with hash values. `intValue0` encodes a bool value for the setting. | -| VulkanBindShift | Specifies the `-fvk-bind-shift` option. `intValue0` (lower 8 bits): kind, `intValue0` (higher bits): set; `intValue1`: shift. | +| VulkanBindShift | Specifies the `-fvk-bind-shift` option. `intValue0` (higher 8 bits): kind, `intValue0` (lower bits): set; `intValue1`: shift. | | VulkanBindGlobals | Specifies the `-fvk-bind-globals` option. `intValue0`: index, `intValue`: set. | | VulkanInvertY | Specifies the `-fvk-invert-y` option. `intValue0` specifies a bool value for the setting. | | VulkanUseEntryPointName | When set, will keep the original name of entrypoints as they are defined in the source instead of renaming them to `main`. `intValue0` specifies a bool value for the setting. | @@ -871,7 +871,7 @@ extern "C" Optimization, // intValue0: OptimizationLevel Obfuscate, // bool - VulkanBindShift, // intValue0 (lower 8 bits): kind; intValue0(higher bits): set; intValue1: shift + VulkanBindShift, // intValue0 (higher 8 bits): kind; intValue0(lower bits): set; intValue1: shift VulkanBindGlobals, // intValue0: index; intValue1: set VulkanInvertY, // bool VulkanUseEntryPointName, // bool diff --git a/source/slang/slang-compiler-options.cpp b/source/slang/slang-compiler-options.cpp index aca5fb8db..100a5719f 100644 --- a/source/slang/slang-compiler-options.cpp +++ b/source/slang/slang-compiler-options.cpp @@ -20,6 +20,114 @@ namespace Slang } } + void CompilerOptionSet::writeCommandLineArgs(Session* globalSession, StringBuilder& sb) + { + for (auto& option : options) + { + auto optionInfoIndex = globalSession->m_commandOptions.findOptionByUserValue(CommandOptions::UserValue(option.key)); + if (optionInfoIndex == -1) + continue; + auto optionInfo = globalSession->m_commandOptions.getOptionAt(optionInfoIndex); + auto nameCommaIndex = optionInfo.names.indexOf(','); + if (nameCommaIndex == -1) nameCommaIndex = optionInfo.names.getLength(); + auto name = optionInfo.names.head(nameCommaIndex); + switch (option.key) + { + case CompilerOptionName::Capability: + for (auto v : option.value) + { + sb << " " << optionInfo.names << " " << v.stringValue; + } + break; + case CompilerOptionName::Include: + for (auto v : option.value) + { + sb << " -I \"" << v.stringValue << "\""; + } + break; + case CompilerOptionName::MacroDefine: + for (auto v : option.value) + { + sb << " -D" << v.stringValue; + if (v.stringValue2.getLength()) + sb << "=" << v.stringValue2; + } + break; + case CompilerOptionName::VulkanBindShift: // intValue0 (higher 8 bits): kind; intValue0(higher bits): set; intValue1: shift + for (auto v : option.value) + { + uint8_t kind; + int set, shift; + v.unpackInt3(kind, set, shift); + switch ((HLSLToVulkanLayoutOptions::Kind)(kind)) + { + case HLSLToVulkanLayoutOptions::Kind::UnorderedAccess: + sb << " -fvk-u-shift"; + break; + case HLSLToVulkanLayoutOptions::Kind::Sampler: + sb << " -fvk-s-shift"; + break; + case HLSLToVulkanLayoutOptions::Kind::ShaderResource: + sb << " -fvk-t-shift"; + break; + case HLSLToVulkanLayoutOptions::Kind::ConstantBuffer: + sb << " -fvk-b-shift"; + break; + default: + continue; + } + sb << " " << shift << " " << set; + } + break; + case CompilerOptionName::VulkanBindShiftAll: // intValue0: set; intValue1: shift + for (auto v : option.value) + { + sb << " -fvk-all-shift " << v.intValue2 << " " << v.intValue; + } + break; + case CompilerOptionName::VulkanBindGlobals: // intValue0: index; intValue1: set + for (auto v : option.value) + { + sb << " " << name << v.intValue << " " << v.intValue2; + } + break; + case CompilerOptionName::Optimization: + for (auto v : option.value) + { + sb << " -O" << v.intValue; + } + break; + case CompilerOptionName::DownstreamArgs: + for (auto v : option.value) + { + List<UnownedStringSlice> lines; + StringUtil::split(v.stringValue2.getUnownedSlice(), '\n', lines); + for (auto l : lines) + { + sb << " -x" << v.stringValue << " " << l.trim(); + } + } + break; + case CompilerOptionName::EmitSpirvDirectly: + case CompilerOptionName::GLSLForceScalarLayout: + case CompilerOptionName::MatrixLayoutRow: + case CompilerOptionName::MatrixLayoutColumn: + case CompilerOptionName::VulkanInvertY: + case CompilerOptionName::VulkanUseEntryPointName: + case CompilerOptionName::VulkanUseGLLayout: + case CompilerOptionName::VulkanEmitReflection: + case CompilerOptionName::EnableEffectAnnotations: + case CompilerOptionName::DefaultImageFormatUnknown: + case CompilerOptionName::DisableDynamicDispatch: + case CompilerOptionName::DisableSpecialization: + case CompilerOptionName::DumpIntermediates: + if (option.value.getCount() && option.value[0].intValue != 0) + sb << " " << name; + break; + } + } + } + void CompilerOptionSet::buildHash(DigestBuilder<SHA1>& builder) { for (auto& kv : options) diff --git a/source/slang/slang-compiler-options.h b/source/slang/slang-compiler-options.h index 0b35f9a29..ea2bbf070 100644 --- a/source/slang/slang-compiler-options.h +++ b/source/slang/slang-compiler-options.h @@ -83,6 +83,8 @@ namespace Slang List<slang::CompilerOptionEntry> entries; List<String> stringPool; }; + + class Session; struct CompilerOptionSet { @@ -92,6 +94,8 @@ namespace Slang static bool allowDuplicate(CompilerOptionName name); + void writeCommandLineArgs(Session* globalSession, StringBuilder& sb); + OrderedDictionary<CompilerOptionName, List<CompilerOptionValue>> options; bool hasOption(CompilerOptionName name) @@ -137,18 +141,13 @@ namespace Slang { for (auto element : value) { - Index index = -1; - // We don't deduplicate downstream args. - if (name != CompilerOptionName::DownstreamArgs) - { - v->findFirstIndex([&](const CompilerOptionValue& existingVal) - { - if (existingVal.kind == CompilerOptionValueKind::Int) - return existingVal.intValue == element.intValue; - else - return existingVal.stringValue == element.stringValue; - }); - } + Index index = v->findFirstIndex([&](const CompilerOptionValue& existingVal) + { + if (existingVal.kind == CompilerOptionValueKind::Int) + return existingVal.intValue == element.intValue; + else + return existingVal.stringValue == element.stringValue; + }); if (index != -1) { if (replaceDuplicate) diff --git a/source/slang/slang-emit-spirv-ops-debug-info-ext.h b/source/slang/slang-emit-spirv-ops-debug-info-ext.h index 3d2a10aab..fcf931f0a 100644 --- a/source/slang/slang-emit-spirv-ops-debug-info-ext.h +++ b/source/slang/slang-emit-spirv-ops-debug-info-ext.h @@ -24,6 +24,11 @@ SpvInst* emitOpDebugLine(SpvInstParent* parent, IRInst* inst, const T& idResultT return emitInst(parent, inst, SpvOpExtInst, idResultType, kResultID, set, SpvWord(103), source, lineStart, lineEnd, colStart, colEnd); } +SpvInst* emitOpDebugEntryPoint(SpvInstParent* parent, IRInst* resultType, SpvInst* set, SpvInst* entryPoint, SpvInst* scope, IRInst* compiler, IRInst* args) +{ + return emitInst(parent, nullptr, SpvOpExtInst, resultType, kResultID, set, SpvWord(107), entryPoint, scope, compiler, args); +} + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/nonsemantic/NonSemantic.Shader.DebugInfo.100.asciidoc#DebugFunction template<typename T> SpvInst* emitOpDebugFunction(SpvInstParent* parent, IRInst* inst, const T& idResultType, SpvInst* set, IRInst* name, SpvInst* type, IRInst* source, IRInst* lineStart, IRInst* colStart, SpvInst* scope, IRInst* linkageName, IRInst* flag, IRInst* scopeLine) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index cf709c438..7cb263b61 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1389,11 +1389,20 @@ struct SPIRVEmitContext bool useForwardDeclaration = (!m_mapIRInstToSpvInst.containsKey(valueType) && as<IRStructType>(valueType) && storageClass == SpvStorageClassPhysicalStorageBuffer); + SpvId valueTypeId; + if (useForwardDeclaration) + { + valueTypeId = getIRInstSpvID(valueType); + } + else + { + auto spvValueType = ensureInst(valueType); + valueTypeId = getID(spvValueType); + } auto resultSpvType = emitOpTypePointer( inst, storageClass, - useForwardDeclaration? getIRInstSpvID(valueType) : getID(ensureInst(valueType)) - ); + valueTypeId); if (useForwardDeclaration) { // After everything has been emitted, we will move the pointer definition to the end @@ -1613,7 +1622,7 @@ struct SPIRVEmitContext emitIntConstant(100, builder.getUIntType()), // ExtDebugInfo version. emitIntConstant(5, builder.getUIntType()), // DWARF version. result, - emitIntConstant(6, builder.getUIntType())); // Language, use HLSL's ID for now. + emitIntConstant(SpvSourceLanguageSlang, builder.getUIntType())); // Language. registerDebugInst(moduleInst, translationUnit); } return result; @@ -1625,6 +1634,9 @@ struct SPIRVEmitContext case kIROp_HLSLTriangleStreamType: case kIROp_HLSLLineStreamType: case kIROp_HLSLPointStreamType: + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: return nullptr; default: { @@ -2117,6 +2129,20 @@ struct SPIRVEmitContext return varInst; } + String getDebugInfoCommandLineArgumentForEntryPoint(IREntryPointDecoration* entryPointDecor) + { + StringBuilder sb; + sb << "-target spirv "; + m_targetProgram->getOptionSet().writeCommandLineArgs(m_targetProgram->getTargetReq()->getSession(), sb); + sb << " -stage " << getStageName(entryPointDecor->getProfile().getStage()); + if (auto entryPointName = as<IRStringLit>(getName(entryPointDecor->getParent()))) + { + sb << " -entry " << entryPointName->getStringSlice(); + } + sb << " -g2"; + return sb.produceString(); + } + /// Emit the given `irFunc` to SPIR-V SpvInst* emitFunc(IRFunc* irFunc) { @@ -2250,6 +2276,22 @@ struct SPIRVEmitContext if (funcDebugScope) { + if (auto entryPointDecor = irFunc->findDecoration<IREntryPointDecoration>()) + { + if (auto debugScope = findDebugScope(irFunc->getModule()->getModuleInst())) + { + IRBuilder builder(irFunc); + String cmdArgs = getDebugInfoCommandLineArgumentForEntryPoint(entryPointDecor); + emitOpDebugEntryPoint( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + m_voidType, + getNonSemanticDebugInfoExtInst(), + funcDebugScope, + debugScope, + builder.getStringValue(toSlice("slangc")), + builder.getStringValue(cmdArgs.getUnownedSlice())); + } + } emitOpDebugScope(spvBlock, nullptr, m_voidType, getNonSemanticDebugInfoExtInst(), funcDebugScope); } diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index e164bd2c6..fb914f7c0 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -880,14 +880,12 @@ static LegalVal legalizeDebugValue(IRTypeLegalizationContext* context, LegalVal } case LegalType::Flavor::tuple: { - auto tupleVar = debugVar.getTuple(); - UInt index = 0; - for (auto ee : tupleVar->elements) + auto tupleVal = debugValue.getTuple(); + for (auto ee : tupleVal->elements) { - auto innerResult = legalizeDebugValue(context, debugVar, debugValue.getTuple()->elements[index].val, originalInst); + auto innerResult = legalizeDebugValue(context, debugVar, ee.val, originalInst); if (innerResult.flavor != LegalVal::Flavor::none) return innerResult; - index++; } return LegalVal(); } diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index d8327a580..d7c3c6bda 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -1037,6 +1037,40 @@ struct PeepholeContext : InstPassBase } break; } + case kIROp_Load: + { + // Load from undef is undef. + if (as<IRLoad>(inst)->getPtr()->getOp() == kIROp_undefined) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto undef = builder.emitUndefined(inst->getDataType()); + inst->replaceUsesWith(undef); + maybeRemoveOldInst(inst); + changed = true; + } + break; + } + case kIROp_Store: + { + // Store undef is no-op. + if (as<IRStore>(inst)->getVal()->getOp() == kIROp_undefined) + { + maybeRemoveOldInst(inst); + changed = true; + } + break; + } + case kIROp_DebugValue: + { + // Update debug value with undef is no-op. + if (as<IRDebugValue>(inst)->getValue()->getOp() == kIROp_undefined) + { + maybeRemoveOldInst(inst); + changed = true; + } + break; + } default: break; } diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index f2979cb79..8652127da 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1694,6 +1694,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + List<IRInst*> m_instsToRemove; void processWorkList() { while (workList.getCount() != 0) @@ -1806,6 +1807,17 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_SPIRVAsm: processSPIRVAsm(as<IRSPIRVAsm>(inst)); break; + case kIROp_DebugValue: + if (!isSimpleDataType(as<IRDebugValue>(inst)->getDebugVar()->getDataType())) + inst->removeAndDeallocate(); + break; + case kIROp_DebugVar: + if (!isSimpleDataType(as<IRDebugVar>(inst)->getDataType())) + { + inst->removeFromParent(); + m_instsToRemove.add(inst); + } + break; case kIROp_Func: eliminateContinueBlocksInFunc(m_module, as<IRFunc>(inst)); [[fallthrough]]; @@ -1853,6 +1865,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } processWorkList(); + for (auto inst : m_instsToRemove) + inst->removeAndDeallocate(); + // Translate types. List<IRHLSLStructuredBufferTypeBase*> instsToProcess; List<IRInst*> textureFootprintTypes; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 8b39e8b45..5b29d23a8 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -196,6 +196,39 @@ bool isValueType(IRInst* dataType) } } +bool isSimpleDataType(IRType* type) +{ + type = (IRType*)unwrapAttributedType(type); + if (as<IRBasicType>(type)) + return true; + switch (type->getOp()) + { + case kIROp_StructType: + { + auto structType = as<IRStructType>(type); + for (auto field : structType->getFields()) + { + if (!isSimpleDataType(field->getFieldType())) + return false; + } + return true; + break; + } + case kIROp_Param: + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_InterfaceType: + case kIROp_AnyValueType: + return true; + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + case kIROp_PtrType: + return isSimpleDataType((IRType*)type->getOperand(0)); + default: + return false; + } +} + IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue) { auto outerGeneric = as<IRGeneric>(findOuterGeneric(value)); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 648ba3531..fd34d81f7 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -87,6 +87,8 @@ inline bool isScalarIntegerType(IRType* type) // No side effect can take place through a value of a "Value" type. bool isValueType(IRInst* type); +bool isSimpleDataType(IRType* type); + inline bool isChildInstOf(IRInst* inst, IRInst* parent) { while (inst) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e4e324def..17ce04fec 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6282,7 +6282,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> } }; -void maybeEmitDebugLine(IRGenContext* context, Stmt* stmt) +void maybeEmitDebugLine(IRGenContext* context, StmtLoweringVisitor& visitor, Stmt* stmt) { if (!context->includeDebugInfo) return; @@ -6296,6 +6296,7 @@ void maybeEmitDebugLine(IRGenContext* context, Stmt* stmt) if (context->shared->mapSourceFileToDebugSourceInst.tryGetValue(source, debugSourceInst)) { auto humaneLoc = context->getLinkage()->getSourceManager()->getHumaneLoc(stmt->loc, SourceLocType::Emit); + visitor.startBlockIfNeeded(stmt); context->irBuilder->emitDebugLine(debugSourceInst, humaneLoc.line, humaneLoc.line, humaneLoc.column, humaneLoc.column + 1); } } @@ -6327,7 +6328,7 @@ void lowerStmt( try { - maybeEmitDebugLine(context, stmt); + maybeEmitDebugLine(context, visitor, stmt); visitor.dispatch(stmt); } // Don't emit any context message for an explicit `AbortCompilationException` @@ -7436,6 +7437,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(context, irParam, decl); maybeSetRate(context, irParam, decl); addVarDecorations(context, irParam, decl); + maybeAddDebugLocationDecoration(context, irParam); if (decl) { @@ -7591,6 +7593,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> maybeSetRate(context, irGlobal, decl); addVarDecorations(context, irGlobal, decl); + maybeAddDebugLocationDecoration(context, irGlobal); if (decl) { diff --git a/tools/gfx-unit-test/compute-trivial.cpp b/tools/gfx-unit-test/compute-trivial.cpp new file mode 100644 index 000000000..5674dd1a7 --- /dev/null +++ b/tools/gfx-unit-test/compute-trivial.cpp @@ -0,0 +1,99 @@ +#include "tools/unit-test/slang-unit-test.h" + +#include "slang-gfx.h" +#include "gfx-test-util.h" +#include "tools/gfx-util/shader-cursor.h" +#include "source/core/slang-basic.h" + +using namespace gfx; + +namespace gfx_test +{ + void computeTrivialTestImpl(IDevice* device, UnitTestContext* context) + { + Slang::ComPtr<ITransientResourceHeap> transientHeap; + ITransientResourceHeap::Desc transientHeapDesc = {}; + transientHeapDesc.constantBufferSize = 4096; + GFX_CHECK_CALL_ABORT( + device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef())); + + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "compute-trivial", "computeMain", slangReflection)); + + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<gfx::IPipelineState> pipelineState; + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + + const int numberCount = 4; + float initialData[] = { 0.0f, 1.0f, 2.0f, 3.0f }; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.sizeInBytes = numberCount * sizeof(float); + bufferDesc.format = gfx::Format::Unknown; + bufferDesc.elementSize = sizeof(float); + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::ShaderResource, + ResourceState::UnorderedAccess, + ResourceState::CopyDestination, + ResourceState::CopySource); + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBufferResource> numbersBuffer; + GFX_CHECK_CALL_ABORT(device->createBufferResource( + bufferDesc, + (void*)initialData, + numbersBuffer.writeRef())); + + ComPtr<IResourceView> bufferView; + IResourceView::Desc viewDesc = {}; + viewDesc.type = IResourceView::Type::UnorderedAccess; + viewDesc.format = Format::Unknown; + GFX_CHECK_CALL_ABORT( + device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef())); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + ICommandQueue::Desc queueDesc = { ICommandQueue::QueueType::Graphics }; + auto queue = device->createCommandQueue(queueDesc); + + auto commandBuffer = transientHeap->createCommandBuffer(); + auto encoder = commandBuffer->encodeComputeCommands(); + + auto rootObject = encoder->bindPipeline(pipelineState); + + // Bind buffer view to the entry point. + ShaderCursor(rootObject).getPath("buffer").setResource(bufferView); + + encoder->dispatchCompute(1, 1, 1); + encoder->endEncoding(); + commandBuffer->close(); + queue->executeCommandBuffer(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult( + device, + numbersBuffer, + Slang::makeArray<float>(1.0f, 2.0f, 3.0f, 4.0f)); + } + + SLANG_UNIT_TEST(computeTrivialD3D12) + { + runTestImpl(computeTrivialTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + } + + SLANG_UNIT_TEST(computeTrivialD3D11) + { + runTestImpl(computeTrivialTestImpl, unitTestContext, Slang::RenderApiFlag::D3D11); + } + + SLANG_UNIT_TEST(computeTrivialVulkan) + { + runTestImpl(computeTrivialTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + } + +} |
