summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-21 17:07:34 -0700
committerGitHub <noreply@github.com>2023-08-21 17:07:34 -0700
commitbd6dbaf7c3ea720b4ed39904fe08878f9dcbd947 (patch)
tree9e8c436e0888d192c462f75e4655a63b51f41648 /source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
parentf94b2f7a328a898c5e3dc1389d08e0b7ce6e092e (diff)
Compile append and consume structured buffers to glsl. (#3142)
* Compile append and consume structured buffers to glsl. * Fix. * Update CI config. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-append-consume-structured-buffer.cpp')
-rw-r--r--source/slang/slang-ir-lower-append-consume-structured-buffer.cpp247
1 files changed, 247 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
new file mode 100644
index 000000000..fa9f16223
--- /dev/null
+++ b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
@@ -0,0 +1,247 @@
+#include "slang-ir-lower-append-consume-structured-buffer.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-layout.h"
+#include "slang-ir-lower-buffer-element-type.h"
+
+namespace Slang
+{
+ static void lowerStructuredBufferType(TargetRequest* target, IRHLSLStructuredBufferTypeBase* type)
+ {
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+
+ auto elementType = type->getElementType();
+
+ // Type.
+ auto structType = builder.createStructType();
+ StringBuilder nameSb;
+ if (type->getOp() == kIROp_HLSLAppendStructuredBufferType)
+ nameSb << "AppendStructuredBuffer_";
+ else
+ nameSb << "ConsumeStructuredBuffer_";
+ getTypeNameHint(nameSb, elementType);
+ nameSb << "_t";
+ builder.addNameHintDecoration(structType, nameSb.produceString().getUnownedSlice());
+
+ auto elementBufferKey = builder.createStructKey();
+ builder.addNameHintDecoration(elementBufferKey, UnownedStringSlice("elements"));
+
+ auto counterBufferKey = builder.createStructKey();
+ builder.addNameHintDecoration(counterBufferKey, UnownedStringSlice("counter"));
+
+ auto elementBufferType = builder.getType(kIROp_HLSLRWStructuredBufferType, elementType);
+ auto counterBufferType = builder.getType(kIROp_HLSLRWStructuredBufferType, builder.getIntType());
+
+ builder.createStructField(structType, elementBufferKey, elementBufferType);
+ builder.createStructField(structType, counterBufferKey, counterBufferType);
+
+ // Type layout.
+ auto layoutRules = getTypeLayoutRuleForBuffer(target, type);
+
+ IRTypeLayout::Builder elementTypeLayoutBuilder(&builder);
+ IRSizeAndAlignment elementSize;
+ getSizeAndAlignment(layoutRules, elementType, &elementSize);
+ elementTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::Uniform, LayoutSize((LayoutSize::RawValue)elementSize.getStride()));
+ auto elementTypeLayout = elementTypeLayoutBuilder.build();
+
+ IRStructuredBufferTypeLayout::Builder elementBufferTypeLayoutBuilder(&builder, elementTypeLayout);
+ elementBufferTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::DescriptorTableSlot, 1);
+ auto elementBufferTypeLayout = elementBufferTypeLayoutBuilder.build();
+
+ IRTypeLayout::Builder counterTypeLayoutBuilder(&builder);
+ counterTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::Uniform, LayoutSize(4));
+ auto counterTypeLayout = counterTypeLayoutBuilder.build();
+
+ IRStructuredBufferTypeLayout::Builder counterBufferTypeLayoutBuilder(&builder, counterTypeLayout);
+ counterBufferTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::DescriptorTableSlot, 1);
+ auto counterBufferTypeLayout = counterBufferTypeLayoutBuilder.build();
+
+ IRVarLayout::Builder elementBufferVarLayoutBuilder(&builder, elementBufferTypeLayout);
+ elementBufferVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot)->offset = 0;
+ auto elementBufferVarLayout = elementBufferVarLayoutBuilder.build();
+
+ IRVarLayout::Builder counterBufferVarLayoutBuilder(&builder, counterBufferTypeLayout);
+ counterBufferVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot)->offset = 1;
+ auto counterBufferVarLayout = counterBufferVarLayoutBuilder.build();
+
+ IRStructTypeLayout::Builder layoutBuilder(&builder);
+ layoutBuilder.addField(elementBufferKey, elementBufferVarLayout);
+ layoutBuilder.addField(counterBufferKey, counterBufferVarLayout);
+ auto typeLayout = layoutBuilder.build();
+
+ builder.addLayoutDecoration(structType, typeLayout);
+
+ IRFunc* appendFunc = nullptr;
+ IRFunc* consumeFunc = nullptr;
+ IRFunc* getDimensionsFunc = nullptr;
+
+ if (type->getOp() == kIROp_HLSLAppendStructuredBufferType)
+ {
+ // Append method.
+ appendFunc = builder.createFunc();
+ builder.addNameHintDecoration(appendFunc, UnownedStringSlice("AppendStructuredBuffer_Append"));
+ IRType* paramTypes[] = { structType, elementType };
+ auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
+ appendFunc->setFullType(funcType);
+ builder.setInsertInto(appendFunc);
+ builder.emitBlock();
+ auto bufferParam = builder.emitParam(structType);
+ auto elementParam = builder.emitParam(elementType);
+ auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey);
+ auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
+ IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
+ auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
+ auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterIncrement, 1, &counterBufferPtr);
+
+ IRInst* getElementPtrArgs[] = { elementBuffer, oldCounter };
+ auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs);
+
+ builder.emitStore(elementBufferPtr, elementParam);
+ builder.emitReturn();
+ }
+ else
+ {
+ // Consume method.
+ consumeFunc = builder.createFunc();
+ builder.addNameHintDecoration(consumeFunc, UnownedStringSlice("ConsumeStructuredBuffer_Consume"));
+ IRType* paramTypes[] = { structType };
+ auto funcType = builder.getFuncType(1, paramTypes, elementType);
+ consumeFunc->setFullType(funcType);
+ builder.setInsertInto(consumeFunc);
+ auto firstBlock = builder.emitBlock();
+ auto bufferParam = builder.emitParam(structType);
+ auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey);
+ auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
+ IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
+ auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
+ auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterDecrement, 1, &counterBufferPtr);
+ auto index = builder.emitSub(builder.getIntType(), oldCounter, builder.getIntValue(builder.getIntType(), 1));
+
+ // Test if index is greater or equal than 0.
+ auto geq = builder.emitGeq(index, builder.getIntValue(builder.getIntType(), 0));
+ auto trueBlock = builder.emitBlock();
+
+ auto falseBlock = builder.emitBlock();
+ auto mergeBlock = builder.emitBlock();
+
+ builder.setInsertInto(firstBlock);
+ builder.emitIfElse(geq, trueBlock, falseBlock, mergeBlock);
+
+ builder.setInsertInto(trueBlock);
+ IRInst* getElementPtrArgs[] = { elementBuffer, index };
+ auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs);
+ auto val = builder.emitLoad(elementBufferPtr);
+ builder.emitReturn(val);
+
+ builder.setInsertInto(falseBlock);
+ auto defaultVal = builder.emitDefaultConstruct(elementType);
+ builder.emitReturn(defaultVal);
+
+ builder.setInsertInto(mergeBlock);
+ builder.emitUnreachable();
+ }
+
+ // GetDimensions method.
+ {
+ getDimensionsFunc = builder.createFunc();
+ builder.addNameHintDecoration(getDimensionsFunc, UnownedStringSlice("StructuredBuffer_GetDimensions"));
+ IRType* paramTypes[] = { structType };
+ auto uint2Type = builder.getVectorType(builder.getUIntType(), 2);
+ auto funcType = builder.getFuncType(1, paramTypes, uint2Type);
+ getDimensionsFunc->setFullType(funcType);
+ builder.setInsertInto(getDimensionsFunc);
+ builder.emitBlock();
+ auto bufferParam = builder.emitParam(structType);
+ auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
+ IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
+ auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
+ auto counter = builder.emitLoad(counterBufferPtr);
+ counter = builder.emitCast(builder.getUIntType(), counter);
+ auto stride = builder.getIntValue(builder.getUIntType(), elementSize.getStride());
+ IRInst* vecArgs[] = { counter, stride };
+ builder.emitReturn(builder.emitMakeVector(uint2Type, 2, vecArgs));
+ }
+
+ // Replace all insts with synthesized functions.
+ traverseUsers(type, [&](IRInst* typeUser)
+ {
+ if (typeUser->getFullType() != type)
+ return;
+ if (auto layoutDecor = typeUser->findDecoration<IRLayoutDecoration>())
+ {
+ // Replace the original StructuredBufferVarLayout with the new StructTypeVarLayout.
+ if (auto varLayout = as<IRVarLayout>(layoutDecor->getLayout()))
+ {
+ IRBuilder subBuilder(typeUser);
+ IRVarLayout::Builder newVarLayoutBuilder(&subBuilder, typeLayout);
+ newVarLayoutBuilder.cloneEverythingButOffsetsFrom(varLayout);
+ for (auto offsetAttr : varLayout->getOffsetAttrs())
+ {
+ auto info = newVarLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
+ info->offset = offsetAttr->getOffset();
+ info->space = offsetAttr->getSpace();
+ info->kind = offsetAttr->getResourceKind();
+ }
+ auto newVarLayout = newVarLayoutBuilder.build();
+ subBuilder.addLayoutDecoration(typeUser, newVarLayout);
+ varLayout->removeAndDeallocate();
+ }
+ }
+ traverseUses(typeUser, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ switch (user->getOp())
+ {
+ case kIROp_StructuredBufferAppend:
+ {
+ IRBuilder subBuilder(user);
+ subBuilder.setInsertBefore(user);
+ IRInst* args[] = { user->getOperand(0), user->getOperand(1) };
+ auto call = subBuilder.emitCallInst(user->getFullType(), appendFunc, 2, args);
+ user->replaceUsesWith(call);
+ user->removeAndDeallocate();
+ break;
+ }
+ case kIROp_StructuredBufferConsume:
+ {
+ IRBuilder subBuilder(user);
+ subBuilder.setInsertBefore(user);
+ IRInst* args[] = { user->getOperand(0) };
+ auto call = subBuilder.emitCallInst(user->getFullType(), consumeFunc, 1, args);
+ user->replaceUsesWith(call);
+ user->removeAndDeallocate();
+ break;
+ }
+ case kIROp_StructuredBufferGetDimensions:
+ {
+ IRBuilder subBuilder(user);
+ subBuilder.setInsertBefore(user);
+ IRInst* args[] = { user->getOperand(0) };
+ auto call = subBuilder.emitCallInst(user->getFullType(), getDimensionsFunc, 1, args);
+ user->replaceUsesWith(call);
+ user->removeAndDeallocate();
+ break;
+ }
+ }
+ });
+ });
+ type->replaceUsesWith(structType);
+ }
+
+ void lowerAppendConsumeStructuredBuffers(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
+ {
+ SLANG_UNUSED(sink);
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_HLSLAppendStructuredBufferType:
+ case kIROp_HLSLConsumeStructuredBufferType:
+ lowerStructuredBufferType(target, as<IRHLSLStructuredBufferTypeBase>(globalInst));
+ break;
+ }
+ }
+ }
+}