diff options
| author | Yong He <yonghe@outlook.com> | 2025-10-03 12:52:26 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-03 19:52:26 +0000 |
| commit | 6a2cf239a89340ed2985d04609499e8c4a2d8f89 (patch) | |
| tree | c80f3c8dc7e89762aeab0ee7830d1ad728665460 /source/slang/slang-ir-wrap-cbuffer-element.cpp | |
| parent | cc8f6a241edb47c43c5698ee33abed4fe57d4566 (diff) | |
Fix legalization crash when processing metal parameter blocks. (#8591)
Closes #7606.
When Slang compile for a bindful target, we will run the resource type
legalization pass to hoist resource typed struct fields outside of the
struct type and define them as global parameters and passing them around
via dedicated function parameters.
When we compile for a bindless target, we don't run this pass.
However, Metal is a hybrid bindful and bindless target. We need to run
type legalization for the constant buffer, but skip type legalization
for parameter block.
The previous attempt to support this behavior is to hack the type
legalization pass to return `LegalVal::simple` when it sees a
`ParameterBlock<T>`. However, whenever the code is accessing
`parameterBlock.someNestedField`, the type of the nested field may get a
`LegalType::tuple`, and now we will run into inconsistent scenarios
where we have a `LegalVal::simple` on the operand val, and but the
legalization logic is expecting that val to be a `LegalType::tuple`.
This breaks a lot of assumptions and invariants in the type legalization
pass, resulting unstable/fragile behavior.
To systematically solve this problem, this change generalizes the
existing legalize buffer element type pass to translate
`ParameterBlock<Texture2D>` (and similar cases) to
`ParameterBlock<Texture2D.Handle>`. So that such parameter block will
always be legalized to `LegalType:::simple` during type legalization,
and we will never run into any inconsistent cases. This allowed us to
get rid of the hacky logic in the type legalization pass to try to
workaround the inconsistencies.
Diffstat (limited to 'source/slang/slang-ir-wrap-cbuffer-element.cpp')
| -rw-r--r-- | source/slang/slang-ir-wrap-cbuffer-element.cpp | 133 |
1 files changed, 133 insertions, 0 deletions
diff --git a/source/slang/slang-ir-wrap-cbuffer-element.cpp b/source/slang/slang-ir-wrap-cbuffer-element.cpp new file mode 100644 index 000000000..9c070ab80 --- /dev/null +++ b/source/slang/slang-ir-wrap-cbuffer-element.cpp @@ -0,0 +1,133 @@ +#include "slang-ir-wrap-cbuffer-element.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +// This pass implements a simple translation that wraps the element type T in a ConstantBuffer<T> +// (or ParameterBlock<T>) type in `struct S { T inner; }`, and replace the ConstantBuffer<T> type +// with ConstantBuffer<S>. This is needed because some backends do not allow certain types to be +// used directly as the element type of a constant buffer. +// For example, Metal does not allow `ParameterBlock<StructuredBuffer<int>>` as that will create +// a double pointer that Metal compiler does not like. We can easily work around this limitation +// by wrapping the `StructuredBuffer<int>` in a struct. + +namespace Slang +{ + +void maybeProvideNameHint( + IRBuilder& builder, + IRStructType* wrappedStructType, + IRParameterGroupType* originalParamGroupType) +{ + StringBuilder sb; + sb << "wrapper_"; + getTypeNameHint(sb, originalParamGroupType->getElementType()); + builder.addNameHintDecoration(wrappedStructType, sb.produceString().getUnownedSlice()); +} + +void wrapCBufferElements(IRModule* module, WrapCBufferElementPolicy* policy) +{ + struct WorkItem + { + IRStructKey* wrappedFieldKey; + IRInst* inst; + IRInst* newParameterGroupType; + }; + + IRBuilder builder(module); + + List<WorkItem> workList; + for (auto globalInst : module->getGlobalInsts()) + { + // Discover all insts whose type is a parameter group type. + if (auto paramGroupType = as<IRParameterGroupType>(globalInst)) + { + if (!policy->shouldWrapBufferElementInStruct(paramGroupType)) + continue; + + // Create the wrapper struct. + builder.setInsertBefore(paramGroupType); + auto structType = builder.createStructType(); + maybeProvideNameHint(builder, structType, paramGroupType); + auto fieldKey = builder.createStructKey(); + builder.addNameHintDecoration(fieldKey, toSlice("inner")); + builder.createStructField(structType, fieldKey, paramGroupType->getElementType()); + + // Create the new parameter group type whose element is the wrapper struct. + List<IRInst*> bufferTypeOperands; + bufferTypeOperands.add(structType); + for (UInt i = 1; i < paramGroupType->getOperandCount(); ++i) + { + bufferTypeOperands.add(paramGroupType->getOperand(i)); + } + auto newParameterGroupType = builder.getType( + paramGroupType->getOp(), + (UInt)bufferTypeOperands.getCount(), + bufferTypeOperands.getArrayView().getBuffer()); + + // Traverse all uses of the parameter group type, and add them to the work list + // for further processing. + traverseUses( + paramGroupType, + [&](IRUse* use) + { + if (use->getUser()->getFullType() != paramGroupType) + return; + WorkItem item; + item.wrappedFieldKey = fieldKey; + item.inst = use->getUser(); + workList.add(item); + }); + paramGroupType->replaceUsesWith(newParameterGroupType); + } + } + + // Now we have a work list of all instructions that uses a parameter group. + // We need to update all uses of parameter group x with `x.inner` instead. + for (auto item : workList) + { + traverseUses( + item.inst, + [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + + // Note that we insert the field address instruction right before each use, instead + // of immediately after the original parameter group inst, because the parameter + // group inst may be defined in a scope that does not allow field address + // instructions. + auto unwrapped = builder.emitFieldAddress(item.inst, item.wrappedFieldKey); + builder.replaceOperand(use, unwrapped); + }); + } +} + +class MetalWrapCBufferElementPolicy : public WrapCBufferElementPolicy +{ +public: + virtual bool shouldWrapBufferElementInStruct(IRParameterGroupType* cbufferType) override + { + // Metal allows structs, scalars, vectors and matrices directly as buffer elements. + if (as<IRStructType>(cbufferType->getElementType())) + return false; + if (as<IRBasicType>(cbufferType->getElementType())) + return false; + if (as<IRMatrixType>(cbufferType->getElementType())) + return false; + if (as<IRVectorType>(cbufferType->getElementType())) + return false; + + // Wrap everything else in a struct. + return true; + } +}; + +void wrapCBufferElementsForMetal(IRModule* module) +{ + MetalWrapCBufferElementPolicy policy = {}; + wrapCBufferElements(module, &policy); +} + +} // namespace Slang |
