diff options
| author | Yong He <yonghe@outlook.com> | 2024-11-06 10:58:09 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-06 10:58:09 -0800 |
| commit | b86703432629bbfd75a902671d15e40c591065a7 (patch) | |
| tree | 2109f87304fa534034f01812551f29be098c4710 /source | |
| parent | f8294202ce8d5658f6308eeaed634058db9bbb4b (diff) | |
[WGSL] Enable arbitrary arrays in uniform buffers. (#5497)
* [WGSL] Enable arbitrary arrays in uniform buffers.
* format code
* Undo irrelevant change and fixups.
* Update expected failure list.
* Fix.
* Rename.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 353 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 5 |
6 files changed, 281 insertions, 100 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c23195f7c..1950f251c 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1455,7 +1455,10 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.meshOutput) legalizeMeshOutputTypes(irModule); - lowerBufferElementTypeToStorageType(targetProgram, irModule); + BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions; + bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer = + isWGPUTarget(targetRequest); + lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions); // Rewrite functions that return arrays to return them via `out` parameter, // since our target languages doesn't allow returning arrays. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9a081f9de..4baa786e3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3566,6 +3566,7 @@ public: IRInst* replaceOperand(IRUse* use, IRInst* newValue); IRInst* getBoolValue(bool value); + IRInst* getIntValue(IRIntegerValue value); IRInst* getIntValue(IRType* type, IRIntegerValue value); IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); IRStringLit* getStringValue(const UnownedStringSlice& slice); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 7d722bf9c..f9017ebe1 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -12,6 +12,47 @@ struct LoweredElementTypeContext { static const IRIntegerValue kMaxArraySizeToUnroll = 32; + enum ConversionMethodKind + { + Func, + Opcode + }; + struct ConversionMethod + { + ConversionMethodKind kind = ConversionMethodKind::Func; + union + { + IRFunc* func; + IROp op; + }; + ConversionMethod() { func = nullptr; } + operator bool() + { + return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop; + } + ConversionMethod& operator=(IRFunc* f) + { + kind = ConversionMethodKind::Func; + this->func = f; + return *this; + } + ConversionMethod& operator=(IROp irop) + { + kind = ConversionMethodKind::Opcode; + this->op = irop; + return *this; + } + IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operand) + { + if (!*this) + return operand; + if (kind == ConversionMethodKind::Func) + return builder.emitCallInst(resultType, func, 1, &operand); + else + return builder.emitIntrinsicInst(resultType, op, 1, &operand); + } + }; + struct LoweredElementTypeInfo { IRType* originalType; @@ -22,25 +63,48 @@ struct LoweredElementTypeContext IRStructKey* loweredInnerStructKey = nullptr; // For matrix/array types that are lowered into a struct type, this is the // struct key of the data field. - IRFunc* convertOriginalToLowered = nullptr; - IRFunc* convertLoweredToOriginal = nullptr; + ConversionMethod convertOriginalToLowered; + ConversionMethod convertLoweredToOriginal; }; Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count]; Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count]; + struct ConversionMethodKey + { + IRType* toType; + IRType* fromType; + bool operator==(const ConversionMethodKey& other) const + { + return toType == other.toType && fromType == other.fromType; + } + HashCode64 getHashCode() const + { + return combineHash(Slang::getHashCode(toType), Slang::getHashCode(fromType)); + } + }; + + Dictionary<ConversionMethodKey, ConversionMethod> conversionMethodMap; + ConversionMethod getConversionMethod(IRType* toType, IRType* fromType) + { + ConversionMethodKey key; + key.toType = toType; + key.fromType = fromType; + ConversionMethod method; + conversionMethodMap.tryGetValue(key, method); + return method; + } + SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; TargetProgram* target; - bool lowerBufferPointer = false; + BufferElementTypeLoweringOptions options; LoweredElementTypeContext( TargetProgram* target, - bool lowerBufferPointer, + BufferElementTypeLoweringOptions inOptions, SlangMatrixLayoutMode inDefaultMatrixLayout) - : target(target) - , defaultMatrixLayout(inDefaultMatrixLayout) - , lowerBufferPointer(lowerBufferPointer) + : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions) { } @@ -133,6 +197,11 @@ struct LoweredElementTypeContext auto element = elements[(Index)(r * colCount + c)]; vecArgs.add(element); } + // Fill in default values for remaining elements in the vector. + for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++) + { + vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); + } auto colVector = builder.emitMakeVector( vectorType, (UInt)vecArgs.getCount(), @@ -150,6 +219,11 @@ struct LoweredElementTypeContext auto element = elements[(Index)(r * colCount + c)]; vecArgs.add(element); } + // Fill in default values for remaining elements in the vector. + for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++) + { + vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); + } auto rowVector = builder.emitMakeVector( vectorType, (UInt)vecArgs.getCount(), @@ -192,13 +266,10 @@ struct LoweredElementTypeContext for (IRIntegerValue ii = 0; ii < count; ++ii) { auto packedElement = builder.emitElementExtract(packedArray, ii); - auto originalElement = innerTypeInfo.convertLoweredToOriginal - ? builder.emitCallInst( - innerTypeInfo.originalType, - innerTypeInfo.convertLoweredToOriginal, - 1, - &packedElement) - : packedElement; + auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply( + builder, + innerTypeInfo.originalType, + packedElement); args[(Index)ii] = originalElement; } result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); @@ -218,13 +289,10 @@ struct LoweredElementTypeContext builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); auto packedElement = builder.emitElementExtract(packedArray, loopParam); - auto originalElement = innerTypeInfo.convertLoweredToOriginal - ? builder.emitCallInst( - innerTypeInfo.originalType, - innerTypeInfo.convertLoweredToOriginal, - 1, - &packedElement) - : packedElement; + auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply( + builder, + innerTypeInfo.originalType, + packedElement); auto varPtr = builder.emitElementAddress(resultVar, loopParam); builder.emitStore(varPtr, originalElement); builder.setInsertInto(loopBreakBlock); @@ -259,13 +327,10 @@ struct LoweredElementTypeContext for (IRIntegerValue ii = 0; ii < count; ++ii) { auto originalElement = builder.emitElementExtract(originalParam, ii); - auto packedElement = innerTypeInfo.convertOriginalToLowered - ? builder.emitCallInst( - innerTypeInfo.loweredType, - innerTypeInfo.convertOriginalToLowered, - 1, - &originalElement) - : originalElement; + auto packedElement = innerTypeInfo.convertOriginalToLowered.apply( + builder, + innerTypeInfo.loweredType, + originalElement); args[(Index)ii] = packedElement; } packedArray = @@ -286,13 +351,10 @@ struct LoweredElementTypeContext builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); auto originalElement = builder.emitElementExtract(originalParam, loopParam); - auto packedElement = innerTypeInfo.convertOriginalToLowered - ? builder.emitCallInst( - innerTypeInfo.loweredType, - innerTypeInfo.convertOriginalToLowered, - 1, - &originalElement) - : originalElement; + auto packedElement = innerTypeInfo.convertOriginalToLowered.apply( + builder, + innerTypeInfo.loweredType, + originalElement); auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam); builder.emitStore(varPtr, packedElement); builder.setInsertInto(loopBreakBlock); @@ -319,6 +381,17 @@ struct LoweredElementTypeContext } } + // Returns the number of elements N that ensures the IRVectorType(elementType,N) + // has 16-byte aligned size and N is no less than `minCount`. + IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount) + { + IRSizeAndAlignment sizeAlignment; + getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment); + if (sizeAlignment.size) + return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size; + return 4; + } + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules) { IRBuilder builder(type); @@ -357,9 +430,18 @@ struct LoweredElementTypeContext builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); auto structKey = builder.createStructKey(); builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - auto vectorType = builder.getVectorType( - matrixType->getElementType(), - isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount()); + auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount(); + if (rules->ruleName == IRTypeLayoutRuleName::Std140 && + options.use16ByteArrayElementForConstantBuffer) + { + // For constant buffer layout, we need to use 16-byte aligned vector if + // we are required to ensure array element types has 16-byte stride. + vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount( + matrixType->getElementType(), + getIntVal(vectorSize))); + } + + auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize); IRSizeAndAlignment elementSizeAlignment; getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment); elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); @@ -382,6 +464,52 @@ struct LoweredElementTypeContext else if (auto arrayType = as<IRArrayType>(type)) { auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules); + + if (rules->ruleName == IRTypeLayoutRuleName::Std140 && + options.use16ByteArrayElementForConstantBuffer) + { + // For constant buffer layout, we need to use 16-byte-aligned vector if + // we are required to ensure array element types has 16-byte stride. + // We only need to handle the case where the element type is a scalar or vector + // type here, because if the element type is a matrix type or struct type, + // the size promotion will be handled during lowering of the element type. + IRType* packedVectorType = nullptr; + if (auto vectorType = as<IRVectorType>(loweredInnerTypeInfo.loweredType)) + { + packedVectorType = builder.getVectorType( + vectorType->getElementType(), + builder.getIntValue(get16ByteAlignedVectorElementCount( + vectorType->getElementType(), + getIntVal(vectorType->getElementCount())))); + if (packedVectorType != loweredInnerTypeInfo.originalType) + { + loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape; + loweredInnerTypeInfo.convertOriginalToLowered = kIROp_VectorReshape; + } + } + else if (auto scalarType = as<IRBasicType>(loweredInnerTypeInfo.loweredType)) + { + packedVectorType = builder.getVectorType( + loweredInnerTypeInfo.loweredType, + get16ByteAlignedVectorElementCount(scalarType, 1)); + loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape; + loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar; + } + if (packedVectorType) + { + loweredInnerTypeInfo.loweredType = packedVectorType; + if (loweredInnerTypeInfo.convertLoweredToOriginal) + conversionMethodMap[ConversionMethodKey{ + packedVectorType, + loweredInnerTypeInfo.originalType}] = + loweredInnerTypeInfo.convertOriginalToLowered; + if (loweredInnerTypeInfo.convertOriginalToLowered) + conversionMethodMap[ConversionMethodKey{ + loweredInnerTypeInfo.originalType, + packedVectorType}] = loweredInnerTypeInfo.convertLoweredToOriginal; + } + } + // For spirv backend, we always want to lower all array types, even if the element type // comes out the same. This is because different layout rules may have different array // stride requirements. @@ -393,6 +521,7 @@ struct LoweredElementTypeContext return info; } } + auto loweredType = builder.createStructType(); info.loweredType = loweredType; StringBuilder nameSB; @@ -486,11 +615,11 @@ struct LoweredElementTypeContext { builder.setInsertAfter(loweredType); info.convertLoweredToOriginal = builder.createFunc(); - builder.setInsertInto(info.convertLoweredToOriginal); + builder.setInsertInto(info.convertLoweredToOriginal.func); builder.addNameHintDecoration( - info.convertLoweredToOriginal, + info.convertLoweredToOriginal.func, UnownedStringSlice("unpackStorage")); - info.convertLoweredToOriginal->setFullType( + info.convertLoweredToOriginal.func->setFullType( builder.getFuncType(1, (IRType**)&loweredType, type)); builder.emitBlock(); auto loweredParam = builder.emitParam(loweredType); @@ -508,13 +637,10 @@ struct LoweredElementTypeContext loweredParam, field->getKey()); auto unpackedField = - fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal - ? builder.emitCallInst( - field->getFieldType(), - fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal, - 1, - &storageField) - : storageField; + fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal.apply( + builder, + field->getFieldType(), + storageField); args.add(unpackedField); fieldId++; } @@ -524,13 +650,13 @@ struct LoweredElementTypeContext // Create pack func. { - builder.setInsertAfter(info.convertLoweredToOriginal); + builder.setInsertAfter(info.convertLoweredToOriginal.func); info.convertOriginalToLowered = builder.createFunc(); - builder.setInsertInto(info.convertOriginalToLowered); + builder.setInsertInto(info.convertOriginalToLowered.func); builder.addNameHintDecoration( - info.convertOriginalToLowered, + info.convertOriginalToLowered.func, UnownedStringSlice("packStorage")); - info.convertOriginalToLowered->setFullType( + info.convertOriginalToLowered.func->setFullType( builder.getFuncType(1, (IRType**)&type, loweredType)); builder.emitBlock(); auto param = builder.emitParam(type); @@ -545,14 +671,10 @@ struct LoweredElementTypeContext } auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); - auto packedField = - fieldLoweredTypeInfo[fieldId].convertOriginalToLowered - ? builder.emitCallInst( - fieldLoweredTypeInfo[fieldId].loweredType, - fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, - 1, - &fieldVal) - : fieldVal; + auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.apply( + builder, + fieldLoweredTypeInfo[fieldId].loweredType, + fieldVal); args.add(packedField); fieldId++; } @@ -587,11 +709,11 @@ struct LoweredElementTypeContext { builder.setInsertAfter(type); info.convertLoweredToOriginal = builder.createFunc(); - builder.setInsertInto(info.convertLoweredToOriginal); + builder.setInsertInto(info.convertLoweredToOriginal.func); builder.addNameHintDecoration( - info.convertLoweredToOriginal, + info.convertLoweredToOriginal.func, UnownedStringSlice("unpackStorage")); - info.convertLoweredToOriginal->setFullType( + info.convertLoweredToOriginal.func->setFullType( builder.getFuncType(1, (IRType**)&info.loweredType, type)); builder.emitBlock(); auto loweredParam = builder.emitParam(info.loweredType); @@ -601,13 +723,13 @@ struct LoweredElementTypeContext // Create pack func. { - builder.setInsertAfter(info.convertLoweredToOriginal); + builder.setInsertAfter(info.convertLoweredToOriginal.func); info.convertOriginalToLowered = builder.createFunc(); - builder.setInsertInto(info.convertOriginalToLowered); + builder.setInsertInto(info.convertOriginalToLowered.func); builder.addNameHintDecoration( - info.convertOriginalToLowered, + info.convertOriginalToLowered.func, UnownedStringSlice("packStorage")); - info.convertOriginalToLowered->setFullType( + info.convertOriginalToLowered.func->setFullType( builder.getFuncType(1, (IRType**)&type, info.loweredType)); builder.emitBlock(); auto param = builder.emitParam(type); @@ -644,6 +766,8 @@ struct LoweredElementTypeContext getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment); loweredTypeInfo[(int)rules->ruleName].set(type, info); mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info); + conversionMethodMap[{info.originalType, info.loweredType}] = info.convertLoweredToOriginal; + conversionMethodMap[{info.loweredType, info.originalType}] = info.convertOriginalToLowered; return info; } @@ -693,7 +817,7 @@ struct LoweredElementTypeContext for (auto globalInst : module->getGlobalInsts()) { IRType* elementType = nullptr; - if (lowerBufferPointer) + if (options.lowerBufferPointer) { if (auto ptrType = as<IRPtrType>(globalInst)) { @@ -729,7 +853,8 @@ struct LoweredElementTypeContext auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules); // If the lowered type is the same as original type, no change is required. - if (!loweredBufferElementTypeInfo.convertLoweredToOriginal) + if (loweredBufferElementTypeInfo.loweredType == + loweredBufferElementTypeInfo.originalType) continue; builder.setInsertBefore(bufferType); @@ -862,8 +987,41 @@ struct LoweredElementTypeContext } } - auto loweredElementTypeInfo = - getLoweredTypeInfo((IRType*)originalElementType, layoutRules); + LoweredElementTypeInfo loweredElementTypeInfo = {}; + if (auto getElementPtr = as<IRGetElementPtr>(ptrVal)) + { + if (auto arrayType = as<IRArrayTypeBase>( + tryGetPointedToType(&builder, getElementPtr->getBase()->getDataType()))) + { + // For WGSL, an array of scalar or vector type will always be converted to + // an array of 16-byte aligned vector type. In this case, we will run into a + // GetElementPtr where the result type is different from the element type of + // the base array. + // We should setup loweredElementTypeInfo so the remaining logic can handle + // this case and insert proper packing/unpacking logic around it. + if (arrayType->getElementType() != originalElementType) + { + loweredElementTypeInfo.loweredType = arrayType->getElementType(); + loweredElementTypeInfo.originalType = (IRType*)originalElementType; + loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod( + loweredElementTypeInfo.originalType, + loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod( + loweredElementTypeInfo.loweredType, + loweredElementTypeInfo.originalType); + } + } + } + + // For general cases we simply check if the element type needs lowering. + // If so we will insert packing/unpacking logic if necessary. + // + if (!loweredElementTypeInfo.loweredType) + { + loweredElementTypeInfo = + getLoweredTypeInfo((IRType*)originalElementType, layoutRules); + } + if (!loweredElementTypeInfo.convertLoweredToOriginal) continue; @@ -891,11 +1049,11 @@ struct LoweredElementTypeContext builder.setInsertBefore(user); auto newLoad = cloneInst(&cloneEnv, &builder, user); newLoad->setFullType(loweredElementTypeInfo.loweredType); - auto unpackedVal = builder.emitCallInst( - (IRType*)originalElementType, - loweredElementTypeInfo.convertLoweredToOriginal, - 1, - &newLoad); + auto unpackedVal = + loweredElementTypeInfo.convertLoweredToOriginal.apply( + builder, + loweredElementTypeInfo.originalType, + newLoad); user->replaceUsesWith(unpackedVal); user->removeAndDeallocate(); break; @@ -910,11 +1068,11 @@ struct LoweredElementTypeContext IRCloneEnv cloneEnv = {}; builder.setInsertBefore(user); auto originalVal = getStoreVal(user); - auto packedVal = builder.emitCallInst( - loweredElementTypeInfo.loweredType, - loweredElementTypeInfo.convertOriginalToLowered, - 1, - &originalVal); + auto packedVal = + loweredElementTypeInfo.convertOriginalToLowered.apply( + builder, + loweredElementTypeInfo.loweredType, + originalVal); if (auto store = as<IRStore>(user)) store->val.set(packedVal); else if (auto sbStore = as<IRRWStructuredBufferStore>(user)) @@ -954,9 +1112,9 @@ struct LoweredElementTypeContext } else { - // If we getting a derived address from the pointer, we need to - // recursively lower the new address. We do so by pushing the - // address inst into the work list. + // If we getting a derived address from the pointer, we need + // to recursively lower the new address. We do so by pushing + // the address inst into the work list. ptrValsWorkList.add(user); } } @@ -973,7 +1131,8 @@ struct LoweredElementTypeContext // an argument, we don't need to do any marshalling here. if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType())) break; - if (lowerBufferPointer && as<IRPtrType>(ptrVal->getDataType())) + if (options.lowerBufferPointer && + as<IRPtrType>(ptrVal->getDataType())) break; // If we are calling a function with an l-value pointer from buffer // access, we need to materialize the object as a local variable, @@ -981,21 +1140,21 @@ struct LoweredElementTypeContext builder.setInsertBefore(user); auto newLoad = builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal); - auto unpackedVal = builder.emitCallInst( - (IRType*)originalElementType, - loweredElementTypeInfo.convertLoweredToOriginal, - 1, - &newLoad); + auto unpackedVal = + loweredElementTypeInfo.convertLoweredToOriginal.apply( + builder, + (IRType*)originalElementType, + newLoad); auto var = builder.emitVar((IRType*)originalElementType); builder.emitStore(var, unpackedVal); use->set(var); builder.setInsertAfter(user); auto newVal = builder.emitLoad(var); - auto packedVal = builder.emitCallInst( - (IRType*)loweredElementTypeInfo.loweredType, - loweredElementTypeInfo.convertOriginalToLowered, - 1, - &newVal); + auto packedVal = + loweredElementTypeInfo.convertOriginalToLowered.apply( + builder, + (IRType*)loweredElementTypeInfo.loweredType, + newVal); builder.emitStore(ptrVal, packedVal); } break; @@ -1148,7 +1307,7 @@ struct LoweredElementTypeContext void lowerBufferElementTypeToStorageType( TargetProgram* target, IRModule* module, - bool lowerBufferPointer) + BufferElementTypeLoweringOptions options) { SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); @@ -1157,7 +1316,7 @@ void lowerBufferElementTypeToStorageType( defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - LoweredElementTypeContext context(target, lowerBufferPointer, defaultMatrixMode); + LoweredElementTypeContext context(target, options, defaultMatrixMode); context.processModule(module); } diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h index d6082798f..2c69c5476 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.h +++ b/source/slang/slang-ir-lower-buffer-element-type.h @@ -8,6 +8,14 @@ class TargetProgram; struct IRTypeLayoutRules; struct IRType; +struct BufferElementTypeLoweringOptions +{ + bool lowerBufferPointer = false; + + // For WGSL, we can only create arrays that has a stride of 16 bytes for constant buffers. + bool use16ByteArrayElementForConstantBuffer = false; +}; + // For each struct type S used as element type of a ConstantBuffer, ParameterBlock or // [RW]StructuredBuffer, we create a lowered type L, where matrix types are lowered to arrays of // vectors based on major-ness, and loads from the buffer are converted to L_to_S(load(buffer)), and @@ -18,7 +26,7 @@ struct IRType; void lowerBufferElementTypeToStorageType( TargetProgram* target, IRModule* module, - bool lowerBufferPointer = false); + BufferElementTypeLoweringOptions options = BufferElementTypeLoweringOptions()); // Returns the type layout rules should be used for a buffer resource type. diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index ff1ddadca..4baa28d67 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -2205,7 +2205,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // pointers in this pass. In the future we should consider separate out IRAddress as // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can // safely lower the pointer load stores early together with other buffer types. - lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true); + BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions; + bufferElementTypeLoweringOptions.lowerBufferPointer = true; + lowerBufferElementTypeToStorageType( + m_sharedContext->m_targetProgram, + m_module, + bufferElementTypeLoweringOptions); // The above step may produce empty struct types, so we need to lower them out of existence. legalizeEmptyTypes(m_sharedContext->m_targetProgram, m_module, m_sink); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2cbafea6c..823b3cd7d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2264,6 +2264,11 @@ IRInst* IRBuilder::getBoolValue(bool inValue) return _findOrEmitConstant(keyInst); } +IRInst* IRBuilder::getIntValue(IRIntegerValue value) +{ + return getIntValue(getIntType(), value); +} + IRInst* IRBuilder::getIntValue(IRType* type, IRIntegerValue inValue) { IRConstant keyInst; |
