diff options
| author | Yong He <yonghe@outlook.com> | 2025-03-06 14:26:34 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-03-06 14:26:34 -0800 |
| commit | 4485cf3eaf142cfd5f8470e86739acc67d4e12ea (patch) | |
| tree | c6ce220dfe5f3ab25ea558f2512f3761c9565c69 /source/slang/slang-ir-lower-buffer-element-type.cpp | |
| parent | 55dd2deaff82bbdb72e125ba4b350030b7e5f427 (diff) | |
Update SPIRV-Tools and fix new validation errors. (#6511)
* Update SPIRV-Tools and fix new validation errors.
* Implement pointers for glsl target.
* Reworked packStorage/unpackStorage code gen to operate on pointers rather than values.
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 431 |
1 files changed, 260 insertions, 171 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 74e84f1ee..6f0e22a57 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -58,14 +58,39 @@ struct LoweredElementTypeContext this->op = irop; return *this; } - IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operand) + IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) { if (!*this) - return operand; + return builder.emitLoad(operandAddr); if (kind == ConversionMethodKind::Func) - return builder.emitCallInst(resultType, func, 1, &operand); + return builder.emitCallInst(resultType, func, 1, &operandAddr); else - return builder.emitIntrinsicInst(resultType, op, 1, &operand); + { + auto val = builder.emitLoad(operandAddr); + return builder.emitIntrinsicInst(resultType, op, 1, &val); + } + } + void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand) + { + if (!*this) + { + builder.emitStore(dest, operand); + return; + } + if (kind == ConversionMethodKind::Func) + { + IRInst* operands[] = {dest, operand}; + builder.emitCallInst(builder.getVoidType(), func, 2, operands); + } + else + { + auto val = builder.emitIntrinsicInst( + tryGetPointedToType(&builder, dest->getDataType()), + op, + 1, + &operand); + builder.emitStore(dest, val); + } } }; @@ -131,21 +156,23 @@ struct LoweredElementTypeContext IRFunc* createMatrixUnpackFunc( IRMatrixType* matrixType, IRStructType* structType, - IRStructKey* dataKey, - IRArrayType* arrayType) + IRStructKey* dataKey) { IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); + auto refStructType = builder.getRefType(structType, AddressSpace::Generic); + auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); auto rowCount = (Index)getIntVal(matrixType->getRowCount()); auto colCount = (Index)getIntVal(matrixType->getColumnCount()); - auto packedParam = builder.emitParam(structType); - auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + auto packedParamRef = builder.emitParam(refStructType); + auto packedParam = builder.emitLoad(packedParamRef); + auto vectorArray = builder.emitFieldExtract(packedParam, dataKey); List<IRInst*> args; args.setCount(rowCount * colCount); if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) @@ -187,13 +214,17 @@ struct LoweredElementTypeContext IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); + auto outStructType = builder.getRefType(structType, AddressSpace::Generic); + IRType* paramTypes[] = {outStructType, matrixType}; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); + auto outParam = builder.emitParam(outStructType); auto originalParam = builder.emitParam(matrixType); List<IRInst*> elements; elements.setCount((Index)(rowCount * colCount)); @@ -255,7 +286,8 @@ struct LoweredElementTypeContext auto vectorArray = builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); auto result = builder.emitMakeStruct(structType, 1, &vectorArray); - builder.emitReturn(result); + builder.emitStore(outParam, result); + builder.emitReturn(); return func; } @@ -263,19 +295,20 @@ struct LoweredElementTypeContext IRArrayType* arrayType, IRStructType* structType, IRStructKey* dataKey, - IRArrayType* innerArrayType, LoweredElementTypeInfo innerTypeInfo) { IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); + auto refStructType = builder.getRefType(structType, AddressSpace::Generic); + auto funcType = builder.getFuncType(1, (IRType**)&refStructType, arrayType); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); - auto packedParam = builder.emitParam(structType); - auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); + auto packedParam = builder.emitParam(refStructType); + auto packedArray = builder.emitFieldAddress(packedParam, dataKey); auto count = getIntVal(arrayType->getElementCount()); IRInst* result = nullptr; if (count <= kMaxArraySizeToUnroll) @@ -285,11 +318,11 @@ struct LoweredElementTypeContext args.setCount((Index)count); for (IRIntegerValue ii = 0; ii < count; ++ii) { - auto packedElement = builder.emitElementExtract(packedArray, ii); + auto packedElementAddr = builder.emitElementAddress(packedArray, ii); auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply( builder, innerTypeInfo.originalType, - packedElement); + packedElementAddr); args[(Index)ii] = originalElement; } result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); @@ -308,11 +341,11 @@ struct LoweredElementTypeContext loopBreakBlock); builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); - auto packedElement = builder.emitElementExtract(packedArray, loopParam); + auto packedElementAddr = builder.emitElementAddress(packedArray, loopParam); auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply( builder, innerTypeInfo.originalType, - packedElement); + packedElementAddr); auto varPtr = builder.emitElementAddress(resultVar, loopParam); builder.emitStore(varPtr, originalElement); builder.setInsertInto(loopBreakBlock); @@ -325,20 +358,24 @@ struct LoweredElementTypeContext IRFunc* createArrayPackFunc( IRArrayType* arrayType, IRStructType* structType, - IRArrayType* innerArrayType, + IRStructKey* arrayStructKey, LoweredElementTypeInfo innerTypeInfo) { IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType); + auto outLoweredType = builder.getRefType(structType, AddressSpace::Generic); + IRType* paramTypes[] = {outLoweredType, structType}; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("packStorage")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); + auto outParam = builder.emitParam(outLoweredType); auto originalParam = builder.emitParam(arrayType); - IRInst* packedArray = nullptr; auto count = getIntVal(arrayType->getElementCount()); + auto destArray = builder.emitFieldAddress(outParam, arrayStructKey); if (count <= kMaxArraySizeToUnroll) { // If the array is small enough, just process each element directly. @@ -347,19 +384,16 @@ struct LoweredElementTypeContext for (IRIntegerValue ii = 0; ii < count; ++ii) { auto originalElement = builder.emitElementExtract(originalParam, ii); - auto packedElement = innerTypeInfo.convertOriginalToLowered.apply( + auto destArrayElement = builder.emitElementAddress(destArray, ii); + innerTypeInfo.convertOriginalToLowered.applyDestinationDriven( builder, - innerTypeInfo.loweredType, + destArrayElement, originalElement); - args[(Index)ii] = packedElement; } - packedArray = - builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); } else { // The general case for large arrays is to emit a loop through the elements. - IRVar* packedArrayVar = builder.emitVar(innerArrayType); IRBlock* loopBodyBlock; IRBlock* loopBreakBlock; auto loopParam = emitLoopBlocks( @@ -371,18 +405,14 @@ struct LoweredElementTypeContext builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); auto originalElement = builder.emitElementExtract(originalParam, loopParam); - auto packedElement = innerTypeInfo.convertOriginalToLowered.apply( + auto varPtr = builder.emitElementAddress(destArray, loopParam); + innerTypeInfo.convertOriginalToLowered.applyDestinationDriven( builder, - innerTypeInfo.loweredType, + varPtr, originalElement); - auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam); - builder.emitStore(varPtr, packedElement); builder.setInsertInto(loopBreakBlock); - packedArray = builder.emitLoad(packedArrayVar); } - - auto result = builder.emitMakeStruct(structType, 1, &packedArray); - builder.emitReturn(result); + builder.emitReturn(); return func; } @@ -451,6 +481,8 @@ struct LoweredElementTypeContext } auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + StringBuilder nameSB; bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; @@ -494,14 +526,14 @@ struct LoweredElementTypeContext info.loweredInnerArrayType = arrayType; info.loweredInnerStructKey = structKey; info.convertLoweredToOriginal = - createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType); + createMatrixUnpackFunc(matrixType, loweredType, structKey); info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); return info; } - else if (auto arrayType = as<IRArrayType>(type)) + else if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) { - auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), config); + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config); if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && options.use16ByteArrayElementForConstantBuffer) @@ -560,42 +592,59 @@ struct LoweredElementTypeContext } } - auto loweredType = builder.createStructType(); - info.loweredType = loweredType; - StringBuilder nameSB; - nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; - getTypeNameHint(nameSB, arrayType->getElementType()); - nameSB << getIntVal(arrayType->getElementCount()); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - loweredInnerTypeInfo.loweredType, - &elementSizeAlignment); - elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment); - auto innerArrayType = builder.getArrayType( - loweredInnerTypeInfo.loweredType, - arrayType->getElementCount(), - builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); - builder.createStructField(loweredType, structKey, innerArrayType); - info.loweredInnerArrayType = innerArrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = createArrayUnpackFunc( - arrayType, - loweredType, - structKey, - innerArrayType, - loweredInnerTypeInfo); - info.convertOriginalToLowered = - createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); - return info; - } - else if (as<IRArrayTypeBase>(type)) - { - info.loweredType = builder.getVoidType(); + auto arrayType = as<IRArrayType>(arrayTypeBase); + if (arrayType) + { + auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + + info.loweredType = loweredType; + StringBuilder nameSB; + nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; + getTypeNameHint(nameSB, arrayType->getElementType()); + nameSB << getIntVal(arrayType->getElementCount()); + builder.addNameHintDecoration( + loweredType, + nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + loweredInnerTypeInfo.loweredType, + &elementSizeAlignment); + elementSizeAlignment = + config.layoutRule->alignCompositeElement(elementSizeAlignment); + auto innerArrayType = builder.getArrayType( + loweredInnerTypeInfo.loweredType, + arrayType->getElementCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + builder.createStructField(loweredType, structKey, innerArrayType); + info.loweredInnerArrayType = innerArrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = + createArrayUnpackFunc(arrayType, loweredType, structKey, loweredInnerTypeInfo); + info.convertOriginalToLowered = + createArrayPackFunc(arrayType, loweredType, structKey, loweredInnerTypeInfo); + } + else + { + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + loweredInnerTypeInfo.loweredType, + &elementSizeAlignment); + elementSizeAlignment = + config.layoutRule->alignCompositeElement(elementSizeAlignment); + auto innerArrayType = builder.getArrayTypeBase( + arrayTypeBase->getOp(), + loweredInnerTypeInfo.loweredType, + nullptr, + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + info.loweredType = innerArrayType; + } return info; } else if (auto structType = as<IRStructType>(type)) @@ -625,6 +674,8 @@ struct LoweredElementTypeContext } } auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + StringBuilder nameSB; getTypeNameHint(nameSB, type); nameSB << "_" << getLayoutName(config.layoutRule->ruleName); @@ -635,12 +686,15 @@ struct LoweredElementTypeContext Index fieldId = 0; for (auto field : structType->getFields()) { - if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) + auto& loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; + // When lowering type for user pointer, skip fields that are unsized array. + if (config.addressSpace == AddressSpace::UserPointer && + as<IRUnsizedArrayType>(loweredFieldTypeInfo.loweredType)) { fieldId++; + loweredFieldTypeInfo.loweredType = builder.getVoidType(); continue; } - auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; builder.createStructField( loweredType, field->getKey(), @@ -657,10 +711,12 @@ struct LoweredElementTypeContext builder.addNameHintDecoration( info.convertLoweredToOriginal.func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(info.convertLoweredToOriginal.func); + auto refLoweredType = builder.getRefType(loweredType, AddressSpace::Generic); info.convertLoweredToOriginal.func->setFullType( - builder.getFuncType(1, (IRType**)&loweredType, type)); + builder.getFuncType(1, (IRType**)&refLoweredType, type)); builder.emitBlock(); - auto loweredParam = builder.emitParam(loweredType); + auto loweredParam = builder.emitParam(refLoweredType); List<IRInst*> args; Index fieldId = 0; for (auto field : structType->getFields()) @@ -670,10 +726,7 @@ struct LoweredElementTypeContext fieldId++; continue; } - auto storageField = builder.emitFieldExtract( - fieldLoweredTypeInfo[fieldId].loweredType, - loweredParam, - field->getKey()); + auto storageField = builder.emitFieldAddress(loweredParam, field->getKey()); auto unpackedField = fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal.apply( builder, @@ -694,9 +747,14 @@ struct LoweredElementTypeContext builder.addNameHintDecoration( info.convertOriginalToLowered.func, UnownedStringSlice("packStorage")); + builder.addForceInlineDecoration(info.convertOriginalToLowered.func); + + auto outLoweredType = builder.getRefType(loweredType, AddressSpace::Generic); + IRType* paramTypes[] = {outLoweredType, type}; info.convertOriginalToLowered.func->setFullType( - builder.getFuncType(1, (IRType**)&type, loweredType)); + builder.getFuncType(2, paramTypes, builder.getVoidType())); builder.emitBlock(); + auto outParam = builder.emitParam(outLoweredType); auto param = builder.emitParam(type); List<IRInst*> args; Index fieldId = 0; @@ -709,15 +767,15 @@ struct LoweredElementTypeContext } auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); - auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.apply( + auto destAddr = builder.emitFieldAddress(outParam, field->getKey()); + + fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.applyDestinationDriven( builder, - fieldLoweredTypeInfo[fieldId].loweredType, + destAddr, fieldVal); - args.add(packedField); fieldId++; } - auto result = builder.emitMakeStruct(loweredType, args); - builder.emitReturn(result); + builder.emitReturn(); } return info; @@ -743,37 +801,8 @@ struct LoweredElementTypeContext info.loweredType = builder.getVectorType( info.loweredType, vectorType->getElementCount()); - // Create unpack func. - { - builder.setInsertAfter(type); - info.convertLoweredToOriginal = builder.createFunc(); - builder.setInsertInto(info.convertLoweredToOriginal.func); - builder.addNameHintDecoration( - info.convertLoweredToOriginal.func, - UnownedStringSlice("unpackStorage")); - info.convertLoweredToOriginal.func->setFullType( - builder.getFuncType(1, (IRType**)&info.loweredType, type)); - builder.emitBlock(); - auto loweredParam = builder.emitParam(info.loweredType); - auto result = builder.emitCast(type, loweredParam); - builder.emitReturn(result); - } - - // Create pack func. - { - builder.setInsertAfter(info.convertLoweredToOriginal.func); - info.convertOriginalToLowered = builder.createFunc(); - builder.setInsertInto(info.convertOriginalToLowered.func); - builder.addNameHintDecoration( - info.convertOriginalToLowered.func, - UnownedStringSlice("packStorage")); - info.convertOriginalToLowered.func->setFullType( - builder.getFuncType(1, (IRType**)&type, info.loweredType)); - builder.emitBlock(); - auto param = builder.emitParam(type); - auto result = builder.emitCast(info.loweredType, param); - builder.emitReturn(result); - } + info.convertLoweredToOriginal = kIROp_BuiltinCast; + info.convertOriginalToLowered = kIROp_BuiltinCast; return info; } } @@ -828,7 +857,8 @@ struct LoweredElementTypeContext IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) { if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) || - as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType)) + as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType) || + as<IRGLSLShaderStorageBufferType>(originalPtrLikeType)) { IRBuilder builder(newElementType); builder.setInsertAfter(newElementType); @@ -859,6 +889,26 @@ struct LoweredElementTypeContext TypeLoweringConfig config; }; + IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst) + { + switch (loadStoreInst->getOp()) + { + case kIROp_Load: + case kIROp_Store: + return loadStoreInst->getOperand(0); + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_RWStructuredBufferStore: + return builder.emitRWStructuredBufferGetElementPtr( + loadStoreInst->getOperand(0), + loadStoreInst->getOperand(1)); + default: + return nullptr; + } + } + void processModule(IRModule* module) { IRBuilder builder(module); @@ -891,6 +941,8 @@ struct LoweredElementTypeContext elementType = structBuffer->getElementType(); else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) elementType = constBuffer->getElementType(); + else if (auto storageBuffer = as<IRGLSLShaderStorageBufferType>(globalInst)) + elementType = storageBuffer->getElementType(); if (as<IRTextureBufferType>(globalInst)) continue; if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && @@ -908,6 +960,10 @@ struct LoweredElementTypeContext { auto bufferType = bufferTypeInfo.bufferType; auto elementType = bufferTypeInfo.elementType; + + if (elementType->findDecoration<IRPhysicalTypeDecoration>()) + continue; + auto config = getTypeLoweringConfigForBuffer(target, bufferType); auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, config); @@ -954,8 +1010,13 @@ struct LoweredElementTypeContext // getOffsetPtr(trailingPtr, index). if (auto fieldAddr = as<IRFieldAddress>(ptrVal)) { - if (auto ptrType = as<IRPtrType>(ptrVal->getDataType())) + auto handleUnsizedArrayAccess = [&]() -> bool { + auto ptrType = as<IRPtrType>(ptrVal->getDataType()); + if (!ptrType) + return false; + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return false; if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType())) { builder.setInsertBefore(ptrVal); @@ -1019,9 +1080,12 @@ struct LoweredElementTypeContext }); SLANG_ASSERT(!ptrVal->hasUses()); ptrVal->removeAndDeallocate(); - continue; + return true; } - } + return false; + }; + if (handleUnsizedArrayAccess()) + continue; } LoweredElementTypeInfo loweredElementTypeInfo = {}; @@ -1060,7 +1124,7 @@ struct LoweredElementTypeContext getLoweredTypeInfo((IRType*)originalElementType, config); } - if (!loweredElementTypeInfo.convertLoweredToOriginal) + if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType) continue; ptrVal->setFullType(getLoweredPtrLikeType( @@ -1083,15 +1147,28 @@ struct LoweredElementTypeContext case kIROp_RWStructuredBufferLoadStatus: case kIROp_StructuredBufferConsume: { - IRCloneEnv cloneEnv = {}; builder.setInsertBefore(user); - auto newLoad = cloneInst(&cloneEnv, &builder, user); - newLoad->setFullType(loweredElementTypeInfo.loweredType); + auto addr = getBufferAddr(builder, user); + if (!addr) + { + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + builder.emitStore(addr, newLoad); + } + if (auto alignedAttr = user->findAttr<IRAlignedAttr>()) + { + builder.addAlignedAddressDecoration( + addr, + alignedAttr->getAlignment()); + } auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply( builder, loweredElementTypeInfo.originalType, - newLoad); + addr); user->replaceUsesWith(unpackedVal); user->removeAndDeallocate(); break; @@ -1106,19 +1183,33 @@ struct LoweredElementTypeContext IRCloneEnv cloneEnv = {}; builder.setInsertBefore(user); auto originalVal = getStoreVal(user); - auto packedVal = - loweredElementTypeInfo.convertOriginalToLowered.apply( - builder, - loweredElementTypeInfo.loweredType, - originalVal); - if (auto store = as<IRStore>(user)) - store->val.set(packedVal); - else if (auto sbStore = as<IRRWStructuredBufferStore>(user)) - sbStore->setOperand(2, packedVal); + IRInst* addr = getBufferAddr(builder, user); + if (addr) + { + if (auto alignedAttr = user->findAttr<IRAlignedAttr>()) + { + builder.addAlignedAddressDecoration( + addr, + alignedAttr->getAlignment()); + } + + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, addr, originalVal); + user->removeAndDeallocate(); + } else if (auto sbAppend = as<IRStructuredBufferAppend>(user)) + { + builder.setInsertBefore(sbAppend); + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, addr, originalVal); + auto packedVal = builder.emitLoad(addr); sbAppend->setOperand(1, packedVal); + } else + { SLANG_UNREACHABLE("unhandled store type"); + } break; } case kIROp_GetElementPtr: @@ -1176,24 +1267,18 @@ struct LoweredElementTypeContext // access, we need to materialize the object as a local variable, // and pass the address of the local variable to the function. builder.setInsertBefore(user); - auto newLoad = - builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal); auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply( builder, (IRType*)originalElementType, - newLoad); + ptrVal); auto var = builder.emitVar((IRType*)originalElementType); builder.emitStore(var, unpackedVal); use->set(var); builder.setInsertAfter(user); auto newVal = builder.emitLoad(var); - auto packedVal = - loweredElementTypeInfo.convertOriginalToLowered.apply( - builder, - (IRType*)loweredElementTypeInfo.loweredType, - newVal); - builder.emitStore(ptrVal, packedVal); + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, ptrVal, newVal); } break; default: @@ -1355,6 +1440,21 @@ void lowerBufferElementTypeToStorageType( context.processModule(module); } +IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout) +{ + switch (layoutTypeOp) + { + case kIROp_DefaultBufferLayoutType: + return defaultLayout; + case kIROp_Std140BufferLayoutType: + return IRTypeLayoutRules::getStd140(); + case kIROp_Std430BufferLayoutType: + return IRTypeLayoutRules::getStd430(); + case kIROp_ScalarBufferLayoutType: + return IRTypeLayoutRules::getNatural(); + } + return defaultLayout; +} IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) { @@ -1395,18 +1495,7 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = structBufferType->getDataLayout() ? structBufferType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - switch (layoutTypeOp) - { - case kIROp_DefaultBufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); - } - return IRTypeLayoutRules::getStd430(); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); } case kIROp_ConstantBufferType: case kIROp_ParameterBlockType: @@ -1416,18 +1505,15 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = parameterGroupType->getDataLayout() ? parameterGroupType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - switch (layoutTypeOp) - { - case kIROp_DefaultBufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); - } - return IRTypeLayoutRules::getStd140(); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140()); + } + case kIROp_GLSLShaderStorageBufferType: + { + auto storageBufferType = as<IRGLSLShaderStorageBufferType>(bufferType); + auto layoutTypeOp = storageBufferType->getDataLayout() + ? storageBufferType->getDataLayout()->getOp() + : kIROp_Std430BufferLayoutType; + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); } case kIROp_PtrType: return IRTypeLayoutRules::getNatural(); @@ -1446,6 +1532,9 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* case AddressSpace::Output: addrSpace = AddressSpace::Input; break; + case AddressSpace::UserPointer: + addrSpace = AddressSpace::UserPointer; + break; } } auto rules = getTypeLayoutRuleForBuffer(target, bufferType); |
