From bd6dbaf7c3ea720b4ed39904fe08878f9dcbd947 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 21 Aug 2023 17:07:34 -0700 Subject: Compile append and consume structured buffers to glsl. (#3142) * Compile append and consume structured buffers to glsl. * Fix. * Update CI config. --------- Co-authored-by: Yong He --- source/slang/hlsl.meta.slang | 26 ++- source/slang/slang-emit-c-like.cpp | 92 +++++++- source/slang/slang-emit-c-like.h | 5 +- source/slang/slang-emit-glsl.cpp | 38 +++- source/slang/slang-emit-glsl.h | 2 + source/slang/slang-emit-torch.cpp | 6 +- source/slang/slang-emit-torch.h | 2 +- source/slang/slang-emit.cpp | 9 + source/slang/slang-ir-byte-address-legalize.cpp | 2 + source/slang/slang-ir-inst-defs.h | 8 + source/slang/slang-ir-insts.h | 19 ++ ...g-ir-lower-append-consume-structured-buffer.cpp | 247 +++++++++++++++++++++ ...ang-ir-lower-append-consume-structured-buffer.h | 17 ++ source/slang/slang-type-layout.cpp | 21 +- source/slang/slang-type-layout.h | 2 + 15 files changed, 480 insertions(+), 16 deletions(-) create mode 100644 source/slang/slang-ir-lower-append-consume-structured-buffer.cpp create mode 100644 source/slang/slang-ir-lower-append-consume-structured-buffer.h (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index b690a5910..3dcdc6c54 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6,16 +6,31 @@ typedef uint UINT; [ForceInline] float3 __asFloat3(float2 v) { return float3(v, 0); } [ForceInline] float3 __asFloat3(float3 v) { return v; } +__generic +__intrinsic_op($(kIROp_StructuredBufferGetDimensions)) +uint2 __structuredBufferGetDimensions(AppendStructuredBuffer buffer); + +__generic +__intrinsic_op($(kIROp_StructuredBufferGetDimensions)) +uint2 __structuredBufferGetDimensions(ConsumeStructuredBuffer buffer); + __generic __magic_type(HLSLAppendStructuredBufferType) __intrinsic_type($(kIROp_HLSLAppendStructuredBufferType)) struct AppendStructuredBuffer { + __intrinsic_op($(kIROp_StructuredBufferAppend)) void Append(T value); + [ForceInline] void GetDimensions( out uint numStructs, - out uint stride); + out uint stride) + { + let result = __structuredBufferGetDimensions(this); + numStructs = result.x; + stride = result.y; + } }; __magic_type(HLSLByteAddressBufferType) @@ -257,11 +272,18 @@ __magic_type(HLSLConsumeStructuredBufferType) __intrinsic_type($(kIROp_HLSLConsumeStructuredBufferType)) struct ConsumeStructuredBuffer { + __intrinsic_op($(kIROp_StructuredBufferConsume)) T Consume(); + [ForceInline] void GetDimensions( out uint numStructs, - out uint stride); + out uint stride) + { + let result = __structuredBufferGetDimensions(this); + numStructs = result.x; + stride = result.y; + } }; __generic diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index e1f631283..75a15d0c9 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -455,6 +455,66 @@ void CLikeSourceEmitter::emitRTTIObject(IRRTTIObject* rttiObject) // This is only used in targets that support dynamic dispatching. } +void CLikeSourceEmitter::defaultEmitInstStmt(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_AtomicCounterIncrement: + { + auto oldValName = getName(inst); + m_writer->emit("int "); + m_writer->emit(oldValName); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", 1, "); + m_writer->emit(oldValName); + m_writer->emit(");\n"); + } + break; + case kIROp_AtomicCounterDecrement: + { + auto oldValName = getName(inst); + m_writer->emit("int "); + m_writer->emit(oldValName); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -1, "); + m_writer->emit(oldValName); + m_writer->emit(");\n"); + } + break; + case kIROp_StructuredBufferGetDimensions: + { + auto count = _generateUniqueName(UnownedStringSlice("_elementCount")); + auto stride = _generateUniqueName(UnownedStringSlice("_stride")); + + m_writer->emit("uint "); + m_writer->emit(count); + m_writer->emit(";\n"); + m_writer->emit("uint "); + m_writer->emit(stride); + m_writer->emit(";\n"); + emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::Postfix))); + m_writer->emit(".GetDimensions("); + m_writer->emit(count); + m_writer->emit(", "); + m_writer->emit(stride); + m_writer->emit(");\n"); + emitInstResultDecl(inst); + m_writer->emit("uint2("); + m_writer->emit(count); + m_writer->emit(", "); + m_writer->emit(stride); + m_writer->emit(");\n"); + } + break; + default: + diagnoseUnhandledInst(inst); + } +} + void CLikeSourceEmitter::emitTypeImpl(IRType* type, const StringSliceLoc* nameAndLoc) { @@ -1874,6 +1934,16 @@ void CLikeSourceEmitter::emitInstExpr(IRInst* inst, const EmitOpInfo& inOuterPre defaultEmitInstExpr(inst, inOuterPrec); } +void CLikeSourceEmitter::emitInstStmt(IRInst* inst) +{ + // Try target specific impl first + if (tryEmitInstStmtImpl(inst)) + { + return; + } + defaultEmitInstStmt(inst); +} + void CLikeSourceEmitter::diagnoseUnhandledInst(IRInst* inst) { getSink()->diagnose(inst, Diagnostics::unimplemented, "unexpected IR opcode during code emit"); @@ -2193,6 +2263,23 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO } break; + case kIROp_StructuredBufferAppend: + { + auto outer = getInfo(EmitOp::General); + emitOperand(inst->getOperand(0), leftSide(outer, getInfo(EmitOp::Postfix))); + m_writer->emit(".Append("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + break; + case kIROp_StructuredBufferConsume: + { + auto outer = getInfo(EmitOp::General); + emitOperand(inst->getOperand(0), leftSide(outer, getInfo(EmitOp::Postfix))); + m_writer->emit(".Consume()"); + } + break; + case kIROp_Call: { emitCallExpr((IRCall*)inst, outerPrec); @@ -2562,7 +2649,10 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) // Insts that needs to be emitted as code blocks. case kIROp_CudaKernelLaunch: - emitInstStmtImpl(inst); + case kIROp_AtomicCounterIncrement: + case kIROp_AtomicCounterDecrement: + case kIROp_StructuredBufferGetDimensions: + emitInstStmt(inst); break; case kIROp_LiveRangeStart: diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 4f8d23a0d..420132a5d 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -549,7 +549,10 @@ public: virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) { SLANG_UNUSED(varDecl); SLANG_UNUSED(varType); return false; } virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { SLANG_UNUSED(inst); SLANG_UNUSED(inOuterPrec); return false; } - virtual void emitInstStmtImpl(IRInst* inst) { SLANG_UNUSED(inst); } + virtual bool tryEmitInstStmtImpl(IRInst* inst) { SLANG_UNUSED(inst); return false; } + + void defaultEmitInstStmt(IRInst* inst); + void emitInstStmt(IRInst* inst); virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) { SLANG_UNUSED(inst); } diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 0920c236c..e1f74f70d 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -201,8 +201,11 @@ void GLSLSourceEmitter::_emitGLSLStructuredBuffer(IRGlobalParam* varDecl, IRHLSL m_writer->emit("buffer "); // Generate a dummy name for the block - m_writer->emit("_S"); - m_writer->emit(m_uniqueIDCounter++); + StringBuilder blockTypeName; + blockTypeName << "StructuredBuffer_"; + getTypeNameHint(blockTypeName, structuredBufferType->getElementType()); + blockTypeName << "_t"; + m_writer->emit(_generateUniqueName(blockTypeName.produceString().getUnownedSlice())); m_writer->emit(" {\n"); m_writer->indent(); @@ -2007,6 +2010,37 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return false; } +bool GLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_AtomicCounterIncrement: + { + auto oldValName = getName(inst); + m_writer->emit("int "); + m_writer->emit(oldValName); + m_writer->emit(" = "); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", 1);\n"); + return true; + } + case kIROp_AtomicCounterDecrement: + { + auto oldValName = getName(inst); + m_writer->emit("int "); + m_writer->emit(oldValName); + m_writer->emit(" = "); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -1);\n"); + return true; + } + default: + return false; + } +} + void GLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) { // Does this function declare any requirements on GLSL version or diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index d0cabfa94..7c1a15315 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -48,6 +48,8 @@ protected: virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE; virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; + virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void emitGlobalInstImpl(IRInst* inst) override; void emitBufferPointerTypeDefinition(IRInst* ptrType); diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index bdb650607..ef04f33ba 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -65,12 +65,12 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type) } } -void TorchCppSourceEmitter::emitInstStmtImpl(IRInst* inst) +bool TorchCppSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) { switch (inst->getOp()) { default: - return; + return false; case kIROp_CudaKernelLaunch: { m_writer->emit("AT_CUDA_CHECK(cudaLaunchKernel("); @@ -101,7 +101,7 @@ void TorchCppSourceEmitter::emitInstStmtImpl(IRInst* inst) emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); m_writer->emit(")));\n"); - break; + return true; } } } diff --git a/source/slang/slang-emit-torch.h b/source/slang/slang-emit-torch.h index aeb9058a4..9e76e42d1 100644 --- a/source/slang/slang-emit-torch.h +++ b/source/slang/slang-emit-torch.h @@ -19,7 +19,7 @@ public: protected: // CPPSourceEmitter overrides - virtual void emitInstStmtImpl(IRInst* inst) override; + virtual bool tryEmitInstStmtImpl(IRInst* inst) override; virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) override; virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) override; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 6521b05ba..03d62b540 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -31,6 +31,7 @@ #include "slang-ir-legalize-varying-params.h" #include "slang-ir-link.h" #include "slang-ir-com-interface.h" +#include "slang-ir-lower-append-consume-structured-buffer.h" #include "slang-ir-lower-binding-query.h" #include "slang-ir-lower-generics.h" #include "slang-ir-lower-tuple-types.h" @@ -494,6 +495,14 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); + // On non-HLSL targets, there isn't an implementation of `AppendStructuredBuffer` + // and `ConsumeStructuredBuffer` types, so we lower them into normal struct types + // of `RWStructuredBuffer` typed fields now. + if (target != CodeGenTarget::HLSL) + { + lowerAppendConsumeStructuredBuffers(targetRequest, irModule, sink); + } + // We don't need the legalize pass for C/C++ based types if(options.shouldLegalizeExistentialAndResourceTypes ) { diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp index 40fe64693..b4de66d77 100644 --- a/source/slang/slang-ir-byte-address-legalize.cpp +++ b/source/slang/slang-ir-byte-address-legalize.cpp @@ -741,6 +741,8 @@ struct ByteAddressBufferLegalizationContext paramBuilder.setInsertBefore(byteAddressBufferParam); auto structuredBufferParam = paramBuilder.createGlobalParam(structuredBufferParamType); + if (auto nameHint = byteAddressBufferParam->findDecoration()) + paramBuilder.addNameHintDecoration(structuredBufferParam, nameHint->getName()); // The new parameter needs to be given a layout to match the existing // parameter, so that it is given the same `binding` in the generated code. diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 22355bd7e..c1b021181 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -447,6 +447,14 @@ INST(RWStructuredBufferStore, rwstructuredBufferStore, 3, 0) INST(RWStructuredBufferGetElementPtr, rwstructuredBufferGetElementPtr, 2, 0) +// Append/Consume-StructuredBuffer operations +INST(StructuredBufferAppend, StructuredBufferAppend, 1, 0) +INST(StructuredBufferConsume, StructuredBufferConsume, 1, 0) +INST(StructuredBufferGetDimensions, StructuredBufferGetDimensions, 1, 0) + +INST(AtomicCounterIncrement, AtomicCounterIncrement, 1, 0) +INST(AtomicCounterDecrement, AtomicCounterDecrement, 1, 0) + INST(MeshOutputRef, meshOutputRef, 2, 0) // Construct a vector from a scalar diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index adfcac7fd..4b0cac182 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2144,6 +2144,25 @@ struct IRRWStructuredBufferGetElementPtr : IRInst IRInst* getIndex() { return getOperand(1); } }; +struct IRStructuredBufferAppend : IRInst +{ + IR_LEAF_ISA(StructuredBufferAppend); + IRInst* getBuffer() { return getOperand(0); } + IRInst* getElement() { return getOperand(1); } +}; + +struct IRStructuredBufferConsume : IRInst +{ + IR_LEAF_ISA(StructuredBufferConsume); + IRInst* getBuffer() { return getOperand(0); } +}; + +struct IRStructuredBufferGetDimensions : IRInst +{ + IR_LEAF_ISA(StructuredBufferGetDimensions); + IRInst* getBuffer() { return getOperand(0); } +}; + struct IRLoadReverseGradient : IRInst { IR_LEAF_ISA(LoadReverseGradient) diff --git a/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp new file mode 100644 index 000000000..fa9f16223 --- /dev/null +++ b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp @@ -0,0 +1,247 @@ +#include "slang-ir-lower-append-consume-structured-buffer.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-layout.h" +#include "slang-ir-lower-buffer-element-type.h" + +namespace Slang +{ + static void lowerStructuredBufferType(TargetRequest* target, IRHLSLStructuredBufferTypeBase* type) + { + IRBuilder builder(type); + builder.setInsertBefore(type); + + auto elementType = type->getElementType(); + + // Type. + auto structType = builder.createStructType(); + StringBuilder nameSb; + if (type->getOp() == kIROp_HLSLAppendStructuredBufferType) + nameSb << "AppendStructuredBuffer_"; + else + nameSb << "ConsumeStructuredBuffer_"; + getTypeNameHint(nameSb, elementType); + nameSb << "_t"; + builder.addNameHintDecoration(structType, nameSb.produceString().getUnownedSlice()); + + auto elementBufferKey = builder.createStructKey(); + builder.addNameHintDecoration(elementBufferKey, UnownedStringSlice("elements")); + + auto counterBufferKey = builder.createStructKey(); + builder.addNameHintDecoration(counterBufferKey, UnownedStringSlice("counter")); + + auto elementBufferType = builder.getType(kIROp_HLSLRWStructuredBufferType, elementType); + auto counterBufferType = builder.getType(kIROp_HLSLRWStructuredBufferType, builder.getIntType()); + + builder.createStructField(structType, elementBufferKey, elementBufferType); + builder.createStructField(structType, counterBufferKey, counterBufferType); + + // Type layout. + auto layoutRules = getTypeLayoutRuleForBuffer(target, type); + + IRTypeLayout::Builder elementTypeLayoutBuilder(&builder); + IRSizeAndAlignment elementSize; + getSizeAndAlignment(layoutRules, elementType, &elementSize); + elementTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::Uniform, LayoutSize((LayoutSize::RawValue)elementSize.getStride())); + auto elementTypeLayout = elementTypeLayoutBuilder.build(); + + IRStructuredBufferTypeLayout::Builder elementBufferTypeLayoutBuilder(&builder, elementTypeLayout); + elementBufferTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::DescriptorTableSlot, 1); + auto elementBufferTypeLayout = elementBufferTypeLayoutBuilder.build(); + + IRTypeLayout::Builder counterTypeLayoutBuilder(&builder); + counterTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::Uniform, LayoutSize(4)); + auto counterTypeLayout = counterTypeLayoutBuilder.build(); + + IRStructuredBufferTypeLayout::Builder counterBufferTypeLayoutBuilder(&builder, counterTypeLayout); + counterBufferTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::DescriptorTableSlot, 1); + auto counterBufferTypeLayout = counterBufferTypeLayoutBuilder.build(); + + IRVarLayout::Builder elementBufferVarLayoutBuilder(&builder, elementBufferTypeLayout); + elementBufferVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot)->offset = 0; + auto elementBufferVarLayout = elementBufferVarLayoutBuilder.build(); + + IRVarLayout::Builder counterBufferVarLayoutBuilder(&builder, counterBufferTypeLayout); + counterBufferVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot)->offset = 1; + auto counterBufferVarLayout = counterBufferVarLayoutBuilder.build(); + + IRStructTypeLayout::Builder layoutBuilder(&builder); + layoutBuilder.addField(elementBufferKey, elementBufferVarLayout); + layoutBuilder.addField(counterBufferKey, counterBufferVarLayout); + auto typeLayout = layoutBuilder.build(); + + builder.addLayoutDecoration(structType, typeLayout); + + IRFunc* appendFunc = nullptr; + IRFunc* consumeFunc = nullptr; + IRFunc* getDimensionsFunc = nullptr; + + if (type->getOp() == kIROp_HLSLAppendStructuredBufferType) + { + // Append method. + appendFunc = builder.createFunc(); + builder.addNameHintDecoration(appendFunc, UnownedStringSlice("AppendStructuredBuffer_Append")); + IRType* paramTypes[] = { structType, elementType }; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); + appendFunc->setFullType(funcType); + builder.setInsertInto(appendFunc); + builder.emitBlock(); + auto bufferParam = builder.emitParam(structType); + auto elementParam = builder.emitParam(elementType); + auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey); + auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey); + IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) }; + auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs); + auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterIncrement, 1, &counterBufferPtr); + + IRInst* getElementPtrArgs[] = { elementBuffer, oldCounter }; + auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs); + + builder.emitStore(elementBufferPtr, elementParam); + builder.emitReturn(); + } + else + { + // Consume method. + consumeFunc = builder.createFunc(); + builder.addNameHintDecoration(consumeFunc, UnownedStringSlice("ConsumeStructuredBuffer_Consume")); + IRType* paramTypes[] = { structType }; + auto funcType = builder.getFuncType(1, paramTypes, elementType); + consumeFunc->setFullType(funcType); + builder.setInsertInto(consumeFunc); + auto firstBlock = builder.emitBlock(); + auto bufferParam = builder.emitParam(structType); + auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey); + auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey); + IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) }; + auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs); + auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterDecrement, 1, &counterBufferPtr); + auto index = builder.emitSub(builder.getIntType(), oldCounter, builder.getIntValue(builder.getIntType(), 1)); + + // Test if index is greater or equal than 0. + auto geq = builder.emitGeq(index, builder.getIntValue(builder.getIntType(), 0)); + auto trueBlock = builder.emitBlock(); + + auto falseBlock = builder.emitBlock(); + auto mergeBlock = builder.emitBlock(); + + builder.setInsertInto(firstBlock); + builder.emitIfElse(geq, trueBlock, falseBlock, mergeBlock); + + builder.setInsertInto(trueBlock); + IRInst* getElementPtrArgs[] = { elementBuffer, index }; + auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs); + auto val = builder.emitLoad(elementBufferPtr); + builder.emitReturn(val); + + builder.setInsertInto(falseBlock); + auto defaultVal = builder.emitDefaultConstruct(elementType); + builder.emitReturn(defaultVal); + + builder.setInsertInto(mergeBlock); + builder.emitUnreachable(); + } + + // GetDimensions method. + { + getDimensionsFunc = builder.createFunc(); + builder.addNameHintDecoration(getDimensionsFunc, UnownedStringSlice("StructuredBuffer_GetDimensions")); + IRType* paramTypes[] = { structType }; + auto uint2Type = builder.getVectorType(builder.getUIntType(), 2); + auto funcType = builder.getFuncType(1, paramTypes, uint2Type); + getDimensionsFunc->setFullType(funcType); + builder.setInsertInto(getDimensionsFunc); + builder.emitBlock(); + auto bufferParam = builder.emitParam(structType); + auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey); + IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) }; + auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs); + auto counter = builder.emitLoad(counterBufferPtr); + counter = builder.emitCast(builder.getUIntType(), counter); + auto stride = builder.getIntValue(builder.getUIntType(), elementSize.getStride()); + IRInst* vecArgs[] = { counter, stride }; + builder.emitReturn(builder.emitMakeVector(uint2Type, 2, vecArgs)); + } + + // Replace all insts with synthesized functions. + traverseUsers(type, [&](IRInst* typeUser) + { + if (typeUser->getFullType() != type) + return; + if (auto layoutDecor = typeUser->findDecoration()) + { + // Replace the original StructuredBufferVarLayout with the new StructTypeVarLayout. + if (auto varLayout = as(layoutDecor->getLayout())) + { + IRBuilder subBuilder(typeUser); + IRVarLayout::Builder newVarLayoutBuilder(&subBuilder, typeLayout); + newVarLayoutBuilder.cloneEverythingButOffsetsFrom(varLayout); + for (auto offsetAttr : varLayout->getOffsetAttrs()) + { + auto info = newVarLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind()); + info->offset = offsetAttr->getOffset(); + info->space = offsetAttr->getSpace(); + info->kind = offsetAttr->getResourceKind(); + } + auto newVarLayout = newVarLayoutBuilder.build(); + subBuilder.addLayoutDecoration(typeUser, newVarLayout); + varLayout->removeAndDeallocate(); + } + } + traverseUses(typeUser, [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_StructuredBufferAppend: + { + IRBuilder subBuilder(user); + subBuilder.setInsertBefore(user); + IRInst* args[] = { user->getOperand(0), user->getOperand(1) }; + auto call = subBuilder.emitCallInst(user->getFullType(), appendFunc, 2, args); + user->replaceUsesWith(call); + user->removeAndDeallocate(); + break; + } + case kIROp_StructuredBufferConsume: + { + IRBuilder subBuilder(user); + subBuilder.setInsertBefore(user); + IRInst* args[] = { user->getOperand(0) }; + auto call = subBuilder.emitCallInst(user->getFullType(), consumeFunc, 1, args); + user->replaceUsesWith(call); + user->removeAndDeallocate(); + break; + } + case kIROp_StructuredBufferGetDimensions: + { + IRBuilder subBuilder(user); + subBuilder.setInsertBefore(user); + IRInst* args[] = { user->getOperand(0) }; + auto call = subBuilder.emitCallInst(user->getFullType(), getDimensionsFunc, 1, args); + user->replaceUsesWith(call); + user->removeAndDeallocate(); + break; + } + } + }); + }); + type->replaceUsesWith(structType); + } + + void lowerAppendConsumeStructuredBuffers(TargetRequest* target, IRModule* module, DiagnosticSink* sink) + { + SLANG_UNUSED(sink); + for (auto globalInst : module->getGlobalInsts()) + { + switch (globalInst->getOp()) + { + case kIROp_HLSLAppendStructuredBufferType: + case kIROp_HLSLConsumeStructuredBufferType: + lowerStructuredBufferType(target, as(globalInst)); + break; + } + } + } +} diff --git a/source/slang/slang-ir-lower-append-consume-structured-buffer.h b/source/slang/slang-ir-lower-append-consume-structured-buffer.h new file mode 100644 index 000000000..81048724d --- /dev/null +++ b/source/slang/slang-ir-lower-append-consume-structured-buffer.h @@ -0,0 +1,17 @@ +// slang-ir-lower-append-consume-structured-buffer.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + class TargetRequest; + + /// For non-hlsl targets, lower append- and consume- structured buffers into `struct` types + /// that contains two RWStructuredBuffer typed fields, one to store the elements, and one + /// for the atomic buffer. + void lowerAppendConsumeStructuredBuffers(TargetRequest* target, IRModule* module, DiagnosticSink* sink); + +} diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index cdf1f3694..978fa6fbb 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -715,6 +715,7 @@ static LayoutResourceKind _getHLSLLayoutResourceKind(ShaderParameterKind kind) case ShaderParameterKind::MutableRawBuffer: case ShaderParameterKind::MutableBuffer: case ShaderParameterKind::MutableTexture: + case ShaderParameterKind::AppendConsumeStructuredBuffer: return LayoutResourceKind::UnorderedAccess; case ShaderParameterKind::SamplerState: @@ -728,6 +729,13 @@ struct GLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl { virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind, const Options& options) override { + int slotCount = 1; + + // In Vulkan GLSL, pretty much every object is just a descriptor-table slot. + // Except for AppendConsumeStructuredBuffer, which takes two slots. + if (kind == ShaderParameterKind::AppendConsumeStructuredBuffer) + slotCount = 2; + if (options.hlslToVulkanKindFlags) { // Is this an HLSL kind that might be shifted @@ -745,14 +753,12 @@ struct GLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl { // We are going to consume a HLSL layout kind // Later we will do shifting as necessary - return SimpleLayoutInfo(hlslLayoutKind, 1); + return SimpleLayoutInfo(hlslLayoutKind, slotCount); } } } - // In Vulkan GLSL, pretty much every object is just a descriptor-table slot. - // We can refine this method once we support a case where this isn't true. - return SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, 1); + return SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, slotCount); } }; GLSLObjectLayoutRulesImpl kGLSLObjectLayoutRulesImpl; @@ -799,6 +805,7 @@ struct HLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl case ShaderParameterKind::MutableRawBuffer: case ShaderParameterKind::MutableBuffer: case ShaderParameterKind::MutableTexture: + case ShaderParameterKind::AppendConsumeStructuredBuffer: return SimpleLayoutInfo(LayoutResourceKind::UnorderedAccess, 1); case ShaderParameterKind::SamplerState: @@ -974,6 +981,7 @@ struct CPUObjectLayoutRulesImpl : ObjectLayoutRulesImpl case ShaderParameterKind::StructuredBuffer: case ShaderParameterKind::MutableStructuredBuffer: + case ShaderParameterKind::AppendConsumeStructuredBuffer: // It's a ptr and a size of the amount of elements return SimpleLayoutInfo(LayoutResourceKind::Uniform, sizeof(void*) * 2, SLANG_ALIGN_OF(void*)); @@ -1033,6 +1041,7 @@ struct CUDAObjectLayoutRulesImpl : CPUObjectLayoutRulesImpl case ShaderParameterKind::StructuredBuffer: case ShaderParameterKind::MutableStructuredBuffer: + case ShaderParameterKind::AppendConsumeStructuredBuffer: { // It's a ptr and a count of the amount of elements const size_t size = _roundToAlignment(sizeof(CUDAPtr) + sizeof(CUDACount), sizeof(CUDAPtr)); @@ -3763,8 +3772,8 @@ static TypeLayoutResult _createTypeLayout( CASE(HLSLStructuredBufferType, StructuredBuffer); CASE(HLSLRWStructuredBufferType, MutableStructuredBuffer); CASE(HLSLRasterizerOrderedStructuredBufferType, MutableStructuredBuffer); - CASE(HLSLAppendStructuredBufferType, MutableStructuredBuffer); - CASE(HLSLConsumeStructuredBufferType, MutableStructuredBuffer); + CASE(HLSLAppendStructuredBufferType, AppendConsumeStructuredBuffer); + CASE(HLSLConsumeStructuredBufferType, AppendConsumeStructuredBuffer); #undef CASE diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index af07f3e73..e3dd719d6 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -935,6 +935,8 @@ enum class ShaderParameterKind MutableImage, RegisterSpace, + + AppendConsumeStructuredBuffer, }; struct SimpleLayoutRulesImpl -- cgit v1.2.3