diff options
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 1507 |
1 files changed, 1110 insertions, 397 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 056ee6244..128502bd8 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -6,6 +6,211 @@ #include "slang-ir-util.h" #include "slang-ir.h" +/// This file implements an important IR transformation pass in the Slang compiler +/// that rewrites buffer element types into valid storage types, a.k.a physical types +/// in SPIRV terminology. +/// +/// Many of our targets have special restrictions on what is allowed to be used as a +/// buffer element. Examples are: +/// - In HLSL and SPIRV, if you have ConstantBuffer<T>, T must be a struct. +/// - In SPIRV, `bool` is considered a logical type, meaning it cannot appear inside +/// buffers. bool vectors and matrices needs to be lowered into arrays. +/// - In SPIRV, if `T` is used to declare a buffer, then every member in `T` must have +/// explicit offset. But if it is used to declare a local variable, then it cannot +/// have explicit member offset. This means that we cannot use the same `Foo` struct +/// inside a `StructuredBuffer<Foo>` and also use it to declare a local variable. +/// +/// We use the terms "physical", "storage", or "lowered" types to refer to types that +/// are legal to use as buffer elements. In contrast, the terms "original" or "logical" +/// refers to types that are declared by the user in its original form. +/// For example, `bool4` is a "logical" type, and its lowered type is `int4`. +/// +/// +/// # Algorithm Overview +/// ---------------------- +/// +/// This pass performs the transformation to create one "storage" type for each type that +/// are used in each kind of buffer. 
For example, if the user defined `Foo`, used it in +/// `ConstantBuffer<Foo>` and `StructuredBuffer<Foo>`, and is targeting SPIRV, this pass will +/// create `Foo_std140` and `Foo_std430` types, and update the buffers to be +/// `ConstantBuffer<Foo_std140>` and `StructuredBuffer<Foo_std430>`. +/// +/// The pass will rewrite all the code that uses these buffers, and insert translations between +/// Foo_std140/Foo_std430 and Foo to keep types consistent. +/// +/// For example, given: +/// ``` +/// struct Foo { +/// bool4x4 v; +/// } +/// ConstantBuffer<Foo> cb; +/// bool test(Foo f) { +/// return f.v[0][1]; +/// } +/// void main() { test(cb); } +/// ``` +/// +/// This pass will rewrite it as: +/// ``` +/// struct Foo { +/// bool4x4 v; +/// } +/// struct Foo_std140 { +/// Matrix_bool4x4_std140 v; +/// }; +/// struct Matrix_bool4x4_std140 { +/// int4 values[4]; +/// }; +/// ConstantBuffer<Foo_std140> cb; +/// bool test_1(Foo_std140 f) { +/// return f.v.values[0][1]; +/// } +/// void main() { test_1(cb); } +/// ``` +/// +/// Note that one important optimization here is that we defer the translation from +/// storage type to logical type to the latest possible time. In the example above, we could +/// have loaded `cb` and then immediately translated it into `Foo` and called `test` with +/// the translated value. However that can lead to code that creates unnecessary copies +/// that can't always be removed by the downstream compiler, particularly if there are +/// arrays whose element type needs non-trivial translation. +/// +/// To avoid the performance issue, we will defer this translation until a logical value +/// is actually needed. This is done by pushing the translation to the use sites, and +/// across function call boundaries, specializing any functions being called along the +/// chain.
In this case, since we are calling `test()` from `main()` with `Foo_std140`, instead +/// of converting the `Foo_std140` to `Foo` before the call, we create a specialization +/// of `test` that accepts `Foo_std140` instead. +/// +/// To enable this interprocedural transformation, the pass is organized as the following steps: +/// 1. Create lowered / storage types for all buffer element types, and update +/// global buffer declarations to use storage types. This is implemented in `processModule()`. +/// 2. Insert a `CastStorageToLogical(loweredBuffer)` inst, and replace all uses of +/// `loweredBuffer` with the cast inst. This is implemented in `processModule()`. +/// 3. Push the `CastStorageToLogical` insts as late as possible, which means if we see +/// `FieldAddress(CastStorageToLogical(storageAddr), memberKey)`, we should translate +/// it into `CastStorageToLogical(FieldAddress(storageAddr, memberKey))`. +/// If we see a `CastStorageToLogical` inst being used as argument to call a function `f`, +/// specialize `f` to take a pointer to the storage type instead, and insert a +/// `CastStorageToLogical(param)` to convert the param type to logical type at the +/// beginning of the specialized function. (implemented in `deferStorageToLogicalCasts()`) +/// +/// Repeat steps 2 and 3 until no more changes can be made, then proceed to step 4. +/// +/// 4. Materialize all remaining `CastStorageToLogical(addr)` by replacing all `load` of such +/// cast insts with `call unpackStorage(addr)`, where `unpackStorage` is a function we +/// synthesize that reads from an address of a storage type and returns a logical type; +/// and replacing all `store(CastStorageToLogical(addr), value)` with `packStorage(addr, value)`, +/// where `packStorage` is a function we synthesize that writes a logical value into a storage +/// addr. This is implemented in `materializeStorageToLogicalCasts()`. +/// +/// That's the main idea of the pass.
+/// +/// # Propagating through SSA values +/// +/// Note that `kIROp_CastStorageToLogical` is a pseudo instruction introduced in this pass that +/// has the semantics of "converting a pointer to a storage value into a pointer to a logical +/// value". A dual of this inst is `kIROp_CastStorageToLogicalDeref`, which has an additional +/// builtin "load" semantic. That is, given `Ptr<StorageType> addr`, `CastStorageToLogical(addr)` +/// will have type `Ptr<LogicalType>`, and `CastStorageToLogicalDeref(addr)` will have type +/// `LogicalType`. In other words, `CastStorageToLogicalDeref(addr)` is equivalent to +/// `load(CastStorageToLogical(addr))`. +/// +/// The `CastStorageToLogicalDeref` pseudo inst is needed to push the deferred cast through `load`s. +/// Consider the following example: +/// ``` +/// ptr : StorageType* = ... +/// lptr : LogicalType* = CastStorageToLogical(ptr); +/// l = load(lptr) +/// m = fieldExtract(l, member) +/// call f, m +/// ``` +/// In this case, only l.member is used, so we should avoid translating other unrelated members +/// from storage type to logical type. To achieve this we must be able to push the +/// `CastStorageToLogical` operation beyond the `load`. The steps to achieve this are: +/// 1. we process `lptr` inst by inspecting its users. We find that a `load` (l) uses it. +/// 2. replace the `load` with `CastStorageToLogicalDeref(ptr)`, the IR becomes: +/// ``` +/// ptr : StorageType* = ... +/// l_1 = CastStorageToLogicalDeref(ptr); +/// m = fieldExtract(l_1, member); +/// call f, m +/// ``` +/// 3. push the new `l_1` inst to worklist, and when it gets processed, we continue to inspect +/// its users, and find that it is being used by `fieldExtract`. We will rewrite the +/// `fieldExtract` into `CastStorageToLogicalDeref(fieldAddr(ptr, member))`, and the IR becomes: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// m = CastStorageToLogicalDeref(m_ptr); +/// call f, m +/// ``` +/// 4.
Since there are no more uses of `m` that can be translated, stop. Note that it is possible +/// to continue specializing `f` and replace its first parameter's type with the storage type. However +/// this implementation currently does not specialize functions whose parameter type is not a +/// pointer/reference type. When we target SPIRV, we will already be running the +/// `transformParamsToConstRef` pass that would have converted `f` to take in `ConstRef<T>`. +/// In this case, the initial IR would be in the form of +/// ``` +/// ptr : StorageType* = ... +/// lptr : LogicalType* = CastStorageToLogical(ptr); +/// l = load(lptr) +/// m = fieldExtract(l, member) +/// var tmpVar : MemberLogicalType [[ImmutableTempVar]] +/// store tmpVar, m +/// call f, tmpVar +/// ``` +/// To allow us to remove the `tmpVar` store introduced during `transformParamsToConstRef`, +/// this pass also handles the propagation through temp var stores. After pushing the cast +/// through `m`, the IR will be in this form: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// m = CastStorageToLogicalDeref(m_ptr); +/// var tmpVar : MemberLogicalType [[ImmutableTempVar]] +/// store tmpVar, m +/// call f, tmpVar +/// ``` +/// This time, we will see that `m` is being used by a `store` into a `[[ImmutableTempVar]]` var, +/// and we can safely replace all uses of `tmpVar` with `CastStorageToLogical(m_ptr)`, and therefore the IR will become: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// m = CastStorageToLogical(m_ptr); +/// call f, m +/// ``` +/// Now, we are in the case where a `CastStorageToLogical` is used as argument in a `call`. +/// This will trigger our function specialization rule to create `f_1` that accepts a +/// `StorageMember*`, and we will rewrite the IR again to: +/// ``` +/// ptr : StorageType* = ...
+/// m_ptr = FieldAddr(ptr, member) +/// call f_1, m_ptr +/// ``` +/// +/// # Trailing Pointer Rewrite +/// +/// Another transformation done in this pass is it also rewrites struct with unsized trailing +/// arrays. Since an unsized type isn't a physical type and cannot be used as a pointee type, +/// we will have problem translating the following code to SPIRV: +/// ``` +/// struct Foo { int count; int[] values; } +/// uniform Foo* b; +/// ``` +/// +/// When we create a storage type for `Foo`, we will define it as: +/// ``` +/// struct Foo_std430 { int count; } +/// ``` +/// Where we removed the trailing array. +/// This makes `Foo_std430` an ordinary sized type that can be used freely as pointee type +/// in SPIRV. +/// +/// However this does mean that we also need to translate things like `ptr->values[2]` +/// into `((int*)(ptr+1))[2]`. Which we also handle during step 2 of the algorithm. +/// (`maybeTranslateTrailingPointerGetElementAddress`) +/// + namespace Slang { @@ -85,7 +290,7 @@ struct LoweredElementTypeContext else { auto val = builder.emitIntrinsicInst( - tryGetPointedToType(&builder, dest->getDataType()), + tryGetPointedToOrBufferElementType(&builder, dest->getDataType()), op, 1, &operand); @@ -145,6 +350,22 @@ struct LoweredElementTypeContext TargetProgram* target; BufferElementTypeLoweringOptions options; + struct SpecializationKey + { + IRFunc* callee; + IRFuncType* specializedFuncType; + bool operator==(const SpecializationKey& other) const + { + return (callee == other.callee && specializedFuncType == other.specializedFuncType); + } + HashCode64 getHashCode() const + { + return combineHash(Slang::getHashCode(callee), Slang::getHashCode(specializedFuncType)); + } + }; + // Specialized functions that takes storage-typed pointers instead of logical-typed pointers. 
+ Dictionary<SpecializationKey, IRFunc*> specializedFuncs; + LoweredElementTypeContext( TargetProgram* target, BufferElementTypeLoweringOptions inOptions, @@ -881,17 +1102,25 @@ struct LoweredElementTypeContext IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) { - if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) || + IRBuilder builder(newElementType); + builder.setInsertAfter(newElementType); + if (auto ptrType = as<IRPtrTypeBase>(originalPtrLikeType)) + { + return builder.getPtrType(newElementType, ptrType); + } + + if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType) || as<IRGLSLShaderStorageBufferType>(originalPtrLikeType)) { - IRBuilder builder(newElementType); - builder.setInsertAfter(newElementType); ShortList<IRInst*> operands; - for (UInt i = 0; i < originalPtrLikeType->getOperandCount(); i++) + operands.add(newElementType); + for (UInt i = 1; i < originalPtrLikeType->getOperandCount(); i++) + { operands.add(originalPtrLikeType->getOperand(i)); - operands[0] = newElementType; - return builder.getType( + } + return (IRType*)builder.emitIntrinsicInst( + builder.getTypeKind(), originalPtrLikeType->getOp(), (UInt)operands.getCount(), operands.getArrayView().getBuffer()); @@ -914,26 +1143,544 @@ struct LoweredElementTypeContext TypeLoweringConfig config; }; - IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst) + IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst, IRInst* baseAddr) { switch (loadStoreInst->getOp()) { case kIROp_Load: case kIROp_Store: - return loadStoreInst->getOperand(0); + return baseAddr; case kIROp_StructuredBufferLoad: case kIROp_StructuredBufferLoadStatus: case kIROp_RWStructuredBufferLoad: case kIROp_RWStructuredBufferLoadStatus: case kIROp_RWStructuredBufferStore: return builder.emitRWStructuredBufferGetElementPtr( - loadStoreInst->getOperand(0), + baseAddr, loadStoreInst->getOperand(1)); 
default: return nullptr; } } + bool maybeTranslateTrailingPointerGetElementAddress( + IRBuilder& builder, + IRFieldAddress* fieldAddr, + IRCastStorageToLogicalBase* castInst, + TypeLoweringConfig& config, + List<IRCastStorageToLogicalBase*>& castInstWorkList) + { + // If we are accessing an unsized array element from a pointer, we need to + // compute + // the trailing ptr that points to the first element of the array. + // And then replace all getElementPtr(arrayPtr, index) with + // getOffsetPtr(trailingPtr, index). + + auto ptrType = as<IRPtrTypeBase>(fieldAddr->getDataType()); + if (!ptrType) + return false; + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return false; + if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType())) + { + builder.setInsertBefore(fieldAddr); + auto newArrayPtrVal = fieldAddr->getBase(); + auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), config); + + IRSizeAndAlignment arrayElementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + loweredInnerType.loweredType, + &arrayElementSizeAlignment); + IRSizeAndAlignment baseSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()), + &baseSizeAlignment); + + // Convert pointer to uint64 and adjust offset. 
+ IRIntegerValue offset = baseSizeAlignment.size; + offset = align(offset, arrayElementSizeAlignment.alignment); + if (offset != 0) + { + auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal); + newArrayPtrVal = builder.emitAdd( + rawPtr->getFullType(), + rawPtr, + builder.getIntValue(builder.getUInt64Type(), offset)); + } + newArrayPtrVal = builder.emitBitCast( + builder.getPtrType(loweredInnerType.loweredType, ptrType), + newArrayPtrVal); + traverseUses( + fieldAddr, + [&](IRUse* fieldAddrUse) + { + auto fieldAddrUser = fieldAddrUse->getUser(); + if (fieldAddrUser->getOp() == kIROp_GetElementPtr) + { + builder.setInsertBefore(fieldAddrUser); + auto newElementPtr = + builder.emitGetOffsetPtr(newArrayPtrVal, fieldAddrUser->getOperand(1)); + auto castedGEP = builder.emitCastStorageToLogical( + fieldAddrUser->getFullType(), + newElementPtr, + castInst->getBufferType()); + fieldAddrUser->replaceUsesWith(castedGEP); + fieldAddrUser->removeAndDeallocate(); + if (auto castStorage = as<IRCastStorageToLogicalBase>(castedGEP)) + castInstWorkList.add(castStorage); + } + else if (fieldAddrUser->getOp() == kIROp_GetOffsetPtr) + { + } + else + { + SLANG_UNEXPECTED("unknown use of pointer to unsized array."); + } + }); + SLANG_ASSERT(!fieldAddr->hasUses()); + fieldAddr->removeAndDeallocate(); + return true; + } + return false; + } + + + // Helper function to discover all `call`s in `func` that has at least one argument + // that is `CastStorageToPhysical`. 
+ void discoverCallsToProcess(List<IRCall*>& callWorkList, IRFunc* func) + { + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + auto call = as<IRCall>(inst); + if (!call) + continue; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (arg->getOp() == kIROp_CastStorageToLogical) + { + callWorkList.add(call); + break; + } + } + } + } + } + + void deferStorageToLogicalCasts( + IRModule* module, + List<IRCastStorageToLogicalBase*> castInstWorkList) + { + IRBuilder builder(module); + + while (castInstWorkList.getCount()) + { + // We process call instructions after other instructions, so we + // can be sure that all castStorageToLogical insts have already + // been pushed to the call argument lists before we process it. + HashSet<IRCall*> callWorkListSet; + // Defer the storage-to-logical cast operation to latest possible time to avoid + // unnecessary packing/unpacking. + for (Index i = 0; i < castInstWorkList.getCount(); i++) + { + auto castInst = castInstWorkList[i]; + auto ptrVal = castInst->getOperand(0); + auto config = + getTypeLoweringConfigForBuffer(target, (IRType*)castInst->getBufferType()); + traverseUses( + castInst, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_FieldAddress: + if (!isUseBaseAddrOperand(use, user)) + break; + // If our logical struct type ends with an unsized array field, the + // storage struct type won't have this field defined. + // Therefore, all fieldAddress(obj, lastField) inst retrieving the last + // field of such struct should be translated into + // `(ArrayElementType*)((StorageStruct*)(obj)+1) + idx`. + // That is, we should first compute the tailing pointer of the + // struct, and replace all getElementPtr(fieldAddr, idx) with + // getOffsetPtr(tailingPtr, idx). 
+ if (maybeTranslateTrailingPointerGetElementAddress( + builder, + (IRFieldAddress*)user, + castInst, + config, + castInstWorkList)) + return; + [[fallthrough]]; + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_RWStructuredBufferGetElementPtr: + { + // gep(castStorageToLogical(x)) ==> castStorageToLogical(gep(x)) + if (!isUseBaseAddrOperand(use, user)) + break; + auto logicalBaseType = castInst->getDataType(); + auto logicalType = user->getDataType(); + IRInst* storageBaseAddr = ptrVal; + auto originalBaseValueType = + tryGetPointedToOrBufferElementType(&builder, logicalBaseType); + if (user->getOp() == kIROp_GetElementPtr) + { + // If original type is an array, the lowered type will be a + // struct. In that case, all existing address insts should be + // appended with a field extract. + if (as<IRArrayType>(originalBaseValueType)) + { + auto arrayLowerInfo = + getLoweredTypeInfo(originalBaseValueType, config); + if (arrayLowerInfo.loweredInnerArrayType) + { + builder.setInsertBefore(user); + List<IRInst*> args; + for (UInt i = 0; i < user->getOperandCount(); i++) + args.add(user->getOperand(i)); + storageBaseAddr = builder.emitFieldAddress( + builder.getPtrType( + arrayLowerInfo.loweredInnerArrayType), + ptrVal, + arrayLowerInfo.loweredInnerStructKey); + } + } + if (as<IRMatrixType>(originalBaseValueType)) + { + // We are tring to get a pointer to a lowered matrix + // element. We process this insts at a later phase. + SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); + lowerMatrixAddresses( + module, + MatrixAddrWorkItem{user, config}); + break; + } + } + + + builder.setInsertBefore(user); + IRInst* storageGEP = nullptr; + switch (user->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + // For standard gep instructions, use the + // IR builder to auto-deduce result type + // of the new GEP inst. 
+ ShortList<IRInst*> newArgs; + for (UInt i = 1; i < user->getOperandCount(); i++) + newArgs.add(user->getOperand(i)); + storageGEP = builder.emitElementAddress( + storageBaseAddr, + newArgs.getArrayView().arrayView); + break; + } + default: + { + // For non-standard gep instructions, e.g. + // RWStructuredBufferGetElementPtr, + // manually create the inst here. + ShortList<IRInst*> newArgs; + newArgs.add(storageBaseAddr); + for (UInt i = 1; i < user->getOperandCount(); i++) + newArgs.add(user->getOperand(i)); + auto logicalValueType = tryGetPointedToOrBufferElementType( + &builder, + logicalType); + auto storageTypeInfo = + getLoweredTypeInfo(logicalValueType, config); + storageGEP = builder.emitIntrinsicInst( + builder.getPtrType(storageTypeInfo.loweredType), + user->getOp(), + newArgs.getCount(), + newArgs.getArrayView().getBuffer()); + break; + } + } + auto castOfGEP = builder.emitCastStorageToLogical( + logicalType, + storageGEP, + castInst->getBufferType()); + user->replaceUsesWith(castOfGEP); + user->removeAndDeallocate(); + if (auto castStorage = as<IRCastStorageToLogical>(castOfGEP)) + castInstWorkList.add(castStorage); + break; + } + case kIROp_Call: + { + // call(f, castStorageToLogical(x)) ==> call(f', x) + // + // If we see a call that takes a logical typed pointer, we will + // specialize the callee to take a storage typed pointer instead, + // and push the cast to inside the callee. + // We will process calls after other gep insts, so for now just add + // it into a separate worklist. 
+ if (castInst->getOp() == kIROp_CastStorageToLogical) + { + callWorkListSet.add((IRCall*)user); + } + break; + } + case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_StructuredBufferConsume: + { + // If we see a load(CastStorageToLogical(storageAddr)), + // then based on what `storageAddr` is, we will push down + // the cast differently. + // - If `storageAddr` is already a tempVar that we introduced to + // hold the value of a buffer resource load, we can simply + // convert this into `CastStorageToLogicalDeref(storageAddr)`. + // - Otherwise, if `storageAddr` is a buffer location, we will + // create a temp var to hold the result of the memory load, + // Then we create a `CastStorageToLogicalDeref(tempVar)` + // structure and use it to replace `user`. + // Note that it is important to introduce a temp var and preserve + // the buffer load operation, so we are not changing the memory + // semantics of the original program. + if (!isUseBaseAddrOperand(use, user)) + break; + // If loaded value is itself a pointer or buffer, + // stop pushing the cast along the resulting address. + // we will handle loads from the pointer separately. + if (as<IRPointerLikeType>(user->getDataType()) || + as<IRPtrTypeBase>(user->getDataType()) || + as<IRHLSLStructuredBufferTypeBase>(user->getDataType())) + break; + // Don't push the cast beyond the load if we are already + // a simple type. 
+ if (!isCompositeType(user->getDataType())) + break; + builder.setInsertBefore(user); + IRCloneEnv cloneEnv; + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setOperand(0, ptrVal); + auto elementStorageType = tryGetPointedToOrBufferElementType( + &builder, + ptrVal->getDataType()); + newLoad->setFullType(elementStorageType); + IRInst* tempVar = nullptr; + if (as<IRLoad>(user)) + { + auto rootAddr = getRootAddr(ptrVal); + if (rootAddr->findDecorationImpl( + kIROp_TempCallArgImmutableVarDecoration)) + tempVar = ptrVal; + } + if (!tempVar) + { + tempVar = builder.emitVar(elementStorageType); + builder.addDecoration( + tempVar, + kIROp_TempCallArgImmutableVarDecoration); + builder.emitStore(tempVar, newLoad); + } + auto newCast = builder.emitCastStorageToLogicalDeref( + user->getFullType(), + tempVar, + castInst->getBufferType()); + user->replaceUsesWith(newCast); + user->removeAndDeallocate(); + castInstWorkList.add(newCast); + break; + } + case kIROp_FieldExtract: + case kIROp_GetElement: + { + if (!isUseBaseAddrOperand(use, user)) + break; + // elementExtract(castStorageToLogicalDeref(addr), key) + // ==> load(gep(castStorageToLogical(addr), key) + builder.setInsertBefore(user); + auto castAddr = builder.emitCastStorageToLogical( + builder.getPtrType(castInst->getDataType()), + ptrVal, + castInst->getBufferType()); + IRInst* gep = nullptr; + if (user->getOp() == kIROp_GetElement) + gep = builder.emitElementAddress(castAddr, user->getOperand(1)); + else + gep = builder.emitFieldAddress(castAddr, user->getOperand(1)); + auto load = builder.emitLoad(gep); + user->replaceUsesWith(load); + user->removeAndDeallocate(); + if (auto castStorage = as<IRCastStorageToLogical>(castAddr)) + castInstWorkList.add(castStorage); + break; + } + case kIROp_Store: + { + // If we see `store(tempVar, castStorageToLogicalDeref(addr))`, + // replace `tempVar` with `castStorageToLogical(addr)`. 
+ if (castInst->getOp() != kIROp_CastStorageToLogicalDeref) + break; + auto store = as<IRStore>(user); + if (store->getVal() != castInst) + break; + auto dest = store->getPtr(); + if (!dest->findDecorationImpl( + kIROp_TempCallArgImmutableVarDecoration)) + break; + builder.setInsertBefore(user); + auto castAddr = builder.emitCastStorageToLogical( + builder.getPtrType(castInst->getDataType()), + ptrVal, + castInst->getBufferType()); + dest->replaceUsesWith(castAddr); + dest->removeAndDeallocate(); + if (auto castStorage = as<IRCastStorageToLogical>(castAddr)) + castInstWorkList.add(castStorage); + break; + } + } + }); + } + + // Now that we have processed all GEP instructions, we can now proceed to + // process all calls. This is done by making a clone of the callee, and change + // the parameter type from logical type to storage type, and insert a + // castStorageToLogical on the parameter. Then we go back to the beginning and make sure + // we process those newly created castStorageToLogical insts. + List<IRCastStorageToLogicalBase*> newCasts; + List<IRCall*> callWorkList; + for (auto call : callWorkListSet) + callWorkList.add(call); + for (Index c = 0; c < callWorkList.getCount(); c++) + { + auto call = callWorkList[c]; + auto calleeFunc = as<IRGlobalValueWithParams>(call->getCallee()); + // We compute the func type for the specialized func based on the arguments + // provided, and check the specialization cache to reuse existing specialization + // when possible. 
+ List<IRInst*> oldParams; + for (auto param : calleeFunc->getParams()) + oldParams.add(param); + SLANG_ASSERT(oldParams.getCount() == (Index)call->getArgCount()); + + ShortList<IRType*> paramTypes; + ShortList<IRInst*> newArgs; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (auto castArg = as<IRCastStorageToLogical>(arg)) + { + auto oldParamPtrType = oldParams[i]->getDataType(); + auto storageValueType = tryGetPointedToOrBufferElementType( + &builder, + castArg->getOperand(0)->getDataType()); + auto storagePtrType = + getLoweredPtrLikeType(oldParamPtrType, storageValueType); + paramTypes.add(storagePtrType); + newArgs.add(castArg->getOperand(0)); + } + else + { + paramTypes.add(arg->getDataType()); + newArgs.add(arg); + } + } + auto specializedFuncType = builder.getFuncType( + (UInt)paramTypes.getCount(), + paramTypes.getArrayView().getBuffer(), + call->getDataType()); + auto key = SpecializationKey{(IRFunc*)calleeFunc, specializedFuncType}; + IRFunc* specializedFunc = nullptr; + if (!specializedFuncs.tryGetValue(key, specializedFunc)) + { + specializedFunc = createSpecializedFuncThatUseStorageType( + call, + specializedFuncType, + newCasts); + specializedFuncs[key] = specializedFunc; + + // The cloned function may also contain `call`s with + // `CastStorageToLogical` arguments, and we want to add + // thoses calls to the callWorkList for further processing. + discoverCallsToProcess(callWorkList, specializedFunc); + } + builder.setInsertBefore(call); + auto newCall = builder.emitCallInst( + call->getFullType(), + specializedFunc, + newArgs.getArrayView().arrayView); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + + // Remove any casts that have no more uses. + for (auto cast : castInstWorkList) + { + if (!cast->hasUses()) + cast->removeAndDeallocate(); + } + + // Continue to process new casts added during function specialization. 
+ castInstWorkList.swapWith(newCasts); + } + } + + IRFunc* createSpecializedFuncThatUseStorageType( + IRCall* call, + IRFuncType* specializedFuncType, + List<IRCastStorageToLogicalBase*>& outNewCasts) + { + IRBuilder builder(call); + builder.setInsertBefore(call->getCallee()); + + // Create a clone of the callee. + IRCloneEnv cloneEnv; + auto clonedFunc = as<IRFunc>(cloneInst(&cloneEnv, &builder, call->getCallee())); + List<IRUse*> uses; + + // If a parameter is being translated to storage type, + // insert a cast to convert it to logical type. + List<IRParam*> params; + for (auto param : clonedFunc->getParams()) + params.add(param); + for (UInt i = 0; i < (UInt)params.getCount(); i++) + { + auto param = params[i]; + SLANG_RELEASE_ASSERT(i < call->getArgCount()); + auto arg = call->getArg(i); + auto cast = as<IRCastStorageToLogical>(arg); + if (!cast) + continue; + auto logicalParamType = param->getFullType(); + auto storageType = specializedFuncType->getParamType(i); + param->setFullType((IRType*)storageType); + setInsertAfterOrdinaryInst(&builder, param); + + // Store uses of param before creating a cast inst that uses it. + uses.clear(); + for (auto use = param->firstUse; use; use = use->nextUse) + uses.add(use); + auto castedParam = + builder.emitCastStorageToLogical(logicalParamType, param, cast->getBufferType()); + if (auto castStorage = as<IRCastStorageToLogicalBase>(castedParam)) + outNewCasts.add(castStorage); + + // Replace all previous uses of param to use castedParam instead. 
+ for (auto use : uses) + builder.replaceOperand(use, castedParam); + } + clonedFunc->setFullType(specializedFuncType); + removeLinkageDecorations(clonedFunc); + return clonedFunc; + } + void processModule(IRModule* module) { IRBuilder builder(module); @@ -941,6 +1688,7 @@ struct LoweredElementTypeContext { IRType* bufferType; IRType* elementType; + IRType* loweredBufferType = nullptr; bool shouldWrapArrayInStruct = false; }; List<BufferTypeInfo> bufferTypeInsts; @@ -990,12 +1738,10 @@ struct LoweredElementTypeContext bufferTypeInsts.add(BufferTypeInfo{(IRType*)globalInst, elementType}); } - // Maintain a pending work list of all matrix addresses, and try to lower them out of - // existance after everything else has been lowered. - List<MatrixAddrWorkItem> matrixAddrInsts; + List<IRCastStorageToLogicalBase*> castInstWorkList; - for (auto bufferTypeInfo : bufferTypeInsts) + for (auto& bufferTypeInfo : bufferTypeInsts) { auto bufferType = bufferTypeInfo.bufferType; auto elementType = bufferTypeInfo.elementType; @@ -1022,10 +1768,10 @@ struct LoweredElementTypeContext (UInt)typeOperands.getCount(), typeOperands.getArrayView().getBuffer()); - // We treat a value of a buffer type as a pointer, and use a work list to translate - // all loads and stores through the pointer values that needs lowering. + // Replace all global buffer declarations to use the storage type instead, + // and insert initial `castStorageToLogical` instructions to convert the + // storage-typed pointer to logical-typed pointer. - List<IRInst*> ptrValsWorkList; traverseUses( bufferType, [&](IRUse* use) @@ -1033,433 +1779,400 @@ struct LoweredElementTypeContext auto user = use->getUser(); if (use != &user->typeUse) return; - ptrValsWorkList.add(use->getUser()); + // We don't want to insert cast instructions for uses of + // intermediate address instruction that are themselves + // derived from some other base address. 
We will let + // the later part of the pass to systematically propagate + // the cast through them. + switch (user->getOp()) + { + case kIROp_FieldAddress: + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_RWStructuredBufferGetElementPtr: + return; + } + auto ptrVal = use->getUser(); + setInsertAfterOrdinaryInst(&builder, ptrVal); + builder.replaceOperand(use, loweredBufferType); + auto logicalBufferType = getLoweredPtrLikeType(bufferType, elementType); + auto castStorageToLogical = + builder.emitCastStorageToLogical(logicalBufferType, ptrVal, bufferType); + traverseUses( + ptrVal, + [&](IRUse* ptrUse) + { + if (ptrUse->getUser() != castStorageToLogical) + builder.replaceOperand(ptrUse, castStorageToLogical); + }); + if (auto castStorage = as<IRCastStorageToLogical>(castStorageToLogical)) + castInstWorkList.add(castStorage); }); + bufferTypeInfo.loweredBufferType = loweredBufferType; + } + + // Push down `CastStorageToLogical` insts we inserted above to latest possible locations, + // specializing all function calls along the way, until we truly need the the logical value. + // This means that `FieldAddr(CastStorageToLogical(buffer), field0))` is translated to + // `CastStorageToLogical(FieldAddr(buffer, field0))`. This way we can be sure that we are + // doing minimal packing/unpacking. + deferStorageToLogicalCasts(module, _Move(castInstWorkList)); + + // Now translate the `CastStorageToLogical` into actual packing/unpacking code. + materializeStorageToLogicalCasts(module->getModuleInst()); + + // Replace all remaining uses of bufferType to loweredBufferType, these uses are + // non-operational and should be directly replaceable, such as uses in `IRFuncType`. 
+ for (auto bufferTypeInst : bufferTypeInsts) + { + if (!bufferTypeInst.loweredBufferType) + continue; + bufferTypeInst.bufferType->replaceUsesWith(bufferTypeInst.loweredBufferType); + bufferTypeInst.bufferType->removeAndDeallocate(); + } + } - // Translate the values to use new lowered buffer type instead. - for (Index i = 0; i < ptrValsWorkList.getCount(); i++) + void materializeStorageToLogicalCastsImpl(IRCastStorageToLogicalBase* castInst) + { + IRBuilder builder(castInst); + if (!castInst->hasUses()) + { + castInst->removeAndDeallocate(); + return; + } + if (castInst->getOp() == kIROp_CastStorageToLogicalDeref) + { + // Convert CastStorageToLogicalDeref to load(CastStorageToLogical) to reuse + // the same materialization logic for CastStorageToLogical. + // + builder.setInsertBefore(castInst); + auto ptrType = builder.getPtrType(castInst->getDataType()); + auto castPtr = builder.emitCastStorageToLogical( + (IRType*)ptrType, + castInst->getVal(), + castInst->getBufferType()); + auto load = builder.emitLoad(castPtr); + castInst->replaceUsesWith(load); + castInst->removeAndDeallocate(); + if (auto castStorage = as<IRCastStorageToLogical>(castPtr)) + materializeStorageToLogicalCastsImpl(castStorage); + return; + } + + // Translate the values to use new lowered buffer type instead. 
+ + auto ptrVal = castInst->getOperand(0); + auto oldPtrType = castInst->getFullType(); + auto originalElementType = oldPtrType->getOperand(0); + auto config = getTypeLoweringConfigForBuffer(target, (IRType*)castInst->getBufferType()); + + + LoweredElementTypeInfo loweredElementTypeInfo = {}; + if (auto getElementPtr = as<IRGetElementPtr>(ptrVal)) + { + if (auto arrayType = as<IRArrayTypeBase>(tryGetPointedToOrBufferElementType( + &builder, + getElementPtr->getBase()->getDataType()))) { - auto ptrVal = ptrValsWorkList[i]; - auto oldPtrType = ptrVal->getFullType(); - auto originalElementType = oldPtrType->getOperand(0); - - // If we are accessing an unsized array element from a pointer, we need to compute - // the trailing ptr that points to the first element of the array. - // And then replace all getElementPtr(arrayPtr, index) with - // getOffsetPtr(trailingPtr, index). - if (auto fieldAddr = as<IRFieldAddress>(ptrVal)) + // For WGSL, an array of scalar or vector type will always be converted to + // an array of 16-byte aligned vector type. In this case, we will run into a + // GetElementPtr where the result type is different from the element type of + // the base array. + // We should setup loweredElementTypeInfo so the remaining logic can handle + // this case and insert proper packing/unpacking logic around it. 
+ if (arrayType->getElementType() != originalElementType && + isScalarOrVectorType(originalElementType)) { - auto handleUnsizedArrayAccess = [&]() -> bool - { - auto ptrType = as<IRPtrType>(ptrVal->getDataType()); - if (!ptrType) - return false; - if (ptrType->getAddressSpace() != AddressSpace::UserPointer) - return false; - if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType())) - { - builder.setInsertBefore(ptrVal); - auto newArrayPtrVal = fieldAddr->getBase(); - auto loweredInnerType = - getLoweredTypeInfo(unsizedArrayType->getElementType(), config); + loweredElementTypeInfo.loweredType = arrayType->getElementType(); + loweredElementTypeInfo.originalType = (IRType*)originalElementType; + loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod( + loweredElementTypeInfo.originalType, + loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod( + loweredElementTypeInfo.loweredType, + loweredElementTypeInfo.originalType); + } + } + } - IRSizeAndAlignment arrayElementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - loweredInnerType.loweredType, - &arrayElementSizeAlignment); - IRSizeAndAlignment baseSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), - &baseSizeAlignment); + // For general cases we simply check if the element type needs lowering. + // If so we will insert packing/unpacking logic if necessary. + // + if (!loweredElementTypeInfo.loweredType) + { + loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType, config); + } - // Convert pointer to uint64 and adjust offset. 
- IRIntegerValue offset = baseSizeAlignment.size; - offset = align(offset, arrayElementSizeAlignment.alignment); - if (offset != 0) - { - auto rawPtr = - builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal); - newArrayPtrVal = builder.emitAdd( - rawPtr->getFullType(), - rawPtr, - builder.getIntValue(builder.getUInt64Type(), offset)); - } - newArrayPtrVal = builder.emitBitCast( - builder.getPtrType( - loweredInnerType.loweredType, - ptrType->getAddressSpace()), - newArrayPtrVal); - traverseUses( - ptrVal, - [&](IRUse* use) - { - auto user = use->getUser(); - if (user->getOp() == kIROp_GetElementPtr) - { - builder.setInsertBefore(user); - auto newElementPtr = builder.emitGetOffsetPtr( - newArrayPtrVal, - user->getOperand(1)); - user->replaceUsesWith(newElementPtr); - user->removeAndDeallocate(); - ptrValsWorkList.add(newElementPtr); - } - else if (user->getOp() == kIROp_GetOffsetPtr) - { - } - else - { - SLANG_UNEXPECTED( - "unknown use of pointer to unsized array."); - } - }); - SLANG_ASSERT(!ptrVal->hasUses()); - ptrVal->removeAndDeallocate(); - return true; - } - return false; - }; - if (handleUnsizedArrayAccess()) - continue; - } + if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType) + { + castInst->replaceUsesWith(ptrVal); + castInst->removeAndDeallocate(); + return; + } - LoweredElementTypeInfo loweredElementTypeInfo = {}; - if (auto getElementPtr = as<IRGetElementPtr>(ptrVal)) + traverseUses( + castInst, + [&](IRUse* use) + { + auto user = use->getUser(); + if (as<IRDecoration>(user)) + return; + switch (user->getOp()) { - if (auto arrayType = as<IRArrayTypeBase>( - tryGetPointedToType(&builder, getElementPtr->getBase()->getDataType()))) + case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_StructuredBufferConsume: { - // For WGSL, an array of scalar or vector type will always be converted to - 
// an array of 16-byte aligned vector type. In this case, we will run into a - // GetElementPtr where the result type is different from the element type of - // the base array. - // We should setup loweredElementTypeInfo so the remaining logic can handle - // this case and insert proper packing/unpacking logic around it. - if (arrayType->getElementType() != originalElementType && - isScalarOrVectorType(originalElementType)) + if (castInst != user->getOperand(0)) + break; + builder.setInsertBefore(user); + auto addr = getBufferAddr(builder, user, ptrVal); + if (!addr) + { + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + builder.emitStore(addr, newLoad); + } + if (auto alignedAttr = user->findAttr<IRAlignedAttr>()) { - loweredElementTypeInfo.loweredType = arrayType->getElementType(); - loweredElementTypeInfo.originalType = (IRType*)originalElementType; - loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod( - loweredElementTypeInfo.originalType, - loweredElementTypeInfo.loweredType); - loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod( - loweredElementTypeInfo.loweredType, - loweredElementTypeInfo.originalType); + builder.addAlignedAddressDecoration(addr, alignedAttr->getAlignment()); } + auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply( + builder, + loweredElementTypeInfo.originalType, + addr); + user->replaceUsesWith(unpackedVal); + user->removeAndDeallocate(); + return; } - } - - // For general cases we simply check if the element type needs lowering. - // If so we will insert packing/unpacking logic if necessary. 
- // - if (!loweredElementTypeInfo.loweredType) - { - loweredElementTypeInfo = - getLoweredTypeInfo((IRType*)originalElementType, config); - } - - if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType) - continue; - - ptrVal->setFullType(getLoweredPtrLikeType( - ptrVal->getFullType(), - loweredElementTypeInfo.loweredType)); - - traverseUses( - ptrVal, - [&](IRUse* use) + case kIROp_Store: + case kIROp_RWStructuredBufferStore: + case kIROp_StructuredBufferAppend: { - auto user = use->getUser(); - if (as<IRDecoration>(user)) - return; - switch (user->getOp()) + // Use must be the dest operand of the store inst. + if (use != user->getOperands() + 0) + break; + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto originalVal = getStoreVal(user); + if (auto sbAppend = as<IRStructuredBufferAppend>(user)) { - case kIROp_Load: - case kIROp_StructuredBufferLoad: - case kIROp_StructuredBufferLoadStatus: - case kIROp_RWStructuredBufferLoad: - case kIROp_RWStructuredBufferLoadStatus: - case kIROp_StructuredBufferConsume: + builder.setInsertBefore(sbAppend); + IRInst* addr = nullptr; + if (originalVal->getOp() == kIROp_CastStorageToLogicalDeref) { - builder.setInsertBefore(user); - auto addr = getBufferAddr(builder, user); - if (!addr) - { - IRCloneEnv cloneEnv = {}; - builder.setInsertBefore(user); - auto newLoad = cloneInst(&cloneEnv, &builder, user); - newLoad->setFullType(loweredElementTypeInfo.loweredType); - addr = builder.emitVar(loweredElementTypeInfo.loweredType); - builder.emitStore(addr, newLoad); - } - if (auto alignedAttr = user->findAttr<IRAlignedAttr>()) - { - builder.addAlignedAddressDecoration( - addr, - alignedAttr->getAlignment()); - } - auto unpackedVal = - loweredElementTypeInfo.convertLoweredToOriginal.apply( - builder, - loweredElementTypeInfo.originalType, - addr); - user->replaceUsesWith(unpackedVal); - user->removeAndDeallocate(); - break; + addr = originalVal->getOperand(0); } - case kIROp_Store: - case 
kIROp_RWStructuredBufferStore: - case kIROp_StructuredBufferAppend: + else { - // Use must be the dest operand of the store inst. - if (use != user->getOperands() + 0) - break; - IRCloneEnv cloneEnv = {}; - builder.setInsertBefore(user); - auto originalVal = getStoreVal(user); - IRInst* addr = getBufferAddr(builder, user); - if (addr) - { - if (auto alignedAttr = user->findAttr<IRAlignedAttr>()) - { - builder.addAlignedAddressDecoration( - addr, - alignedAttr->getAlignment()); - } - - loweredElementTypeInfo.convertOriginalToLowered - .applyDestinationDriven(builder, addr, originalVal); - user->removeAndDeallocate(); - } - else if (auto sbAppend = as<IRStructuredBufferAppend>(user)) - { - builder.setInsertBefore(sbAppend); - addr = builder.emitVar(loweredElementTypeInfo.loweredType); - loweredElementTypeInfo.convertOriginalToLowered - .applyDestinationDriven(builder, addr, originalVal); - auto packedVal = builder.emitLoad(addr); - sbAppend->setOperand(1, packedVal); - } - else - { - SLANG_UNREACHABLE("unhandled store type"); - } - break; + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, addr, originalVal); } - case kIROp_GetElementPtr: - case kIROp_FieldAddress: + auto packedVal = builder.emitLoad(addr); + sbAppend->setOperand(1, packedVal); + } + else + { + IRInst* addr = getBufferAddr(builder, user, ptrVal); + if (auto alignedAttr = user->findAttr<IRAlignedAttr>()) { - // If original type is an array, the lowered type will be a struct. - // In that case, all existing address insts should be appended with - // a field extract. 
- if (as<IRArrayType>(originalElementType)) - { - builder.setInsertBefore(user); - List<IRInst*> args; - for (UInt i = 0; i < user->getOperandCount(); i++) - args.add(user->getOperand(i)); - auto newArrayPtrVal = builder.emitFieldAddress( - builder.getPtrType( - loweredElementTypeInfo.loweredInnerArrayType), - ptrVal, - loweredElementTypeInfo.loweredInnerStructKey); - builder.replaceOperand(use, newArrayPtrVal); - ptrValsWorkList.add(user); - } - else if (as<IRMatrixType>(originalElementType)) - { - // We are tring to get a pointer to a lowered matrix element. - // We process this insts at a later phase. - SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); - matrixAddrInsts.add(MatrixAddrWorkItem{user, config}); - } - else - { - // If we getting a derived address from the pointer, we need - // to recursively lower the new address. We do so by pushing - // the address inst into the work list. - ptrValsWorkList.add(user); - } + builder.addAlignedAddressDecoration( + addr, + alignedAttr->getAlignment()); } - break; - case kIROp_RWStructuredBufferGetElementPtr: - case kIROp_GetOffsetPtr: - ptrValsWorkList.add(user); - break; - case kIROp_StructuredBufferGetDimensions: - break; - case kIROp_Call: + if (originalVal->getOp() == kIROp_CastStorageToLogicalDeref) + { + auto valAddr = originalVal->getOperand(0); + auto storageVal = builder.emitLoad(valAddr); + builder.emitStore(addr, storageVal); + } + else { - // If a structured buffer or pointer typed value is used directly as - // an argument, we don't need to do any marshalling here. - if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType())) - break; - if (options.lowerBufferPointer && - as<IRPtrType>(ptrVal->getDataType())) - break; - // If we are calling a function with an l-value pointer from buffer - // access, we need to materialize the object as a local variable, - // and pass the address of the local variable to the function. 
- builder.setInsertBefore(user); - auto unpackedVal = - loweredElementTypeInfo.convertLoweredToOriginal.apply( - builder, - (IRType*)originalElementType, - ptrVal); - auto var = builder.emitVar((IRType*)originalElementType); - builder.emitStore(var, unpackedVal); - use->set(var); - builder.setInsertAfter(user); - auto newVal = builder.emitLoad(var); loweredElementTypeInfo.convertOriginalToLowered - .applyDestinationDriven(builder, ptrVal, newVal); + .applyDestinationDriven(builder, addr, originalVal); } - break; - default: - break; + user->removeAndDeallocate(); } - }); - } + return; + } + default: + break; + } + // If the pointer is used in any other way that we don't recognize, + // preserve it as is without translation. + builder.setInsertBefore(user); + builder.replaceOperand(use, ptrVal); + }); + + if (!castInst->hasUses()) + castInst->removeAndDeallocate(); + } - // Replace all remaining uses of bufferType to loweredBufferType, these uses are - // non-operational and should be directly replaceable, such as uses in `IRFuncType`. - bufferType->replaceUsesWith(loweredBufferType); - bufferType->removeAndDeallocate(); + void collectInstsOfType(List<IRCastStorageToLogicalBase*>& insts, IRInst* root, IROp op) + { + if (root->getOp() == op) + { + insts.add((IRCastStorageToLogicalBase*)root); + return; + } + for (auto child : root->getChildren()) + { + collectInstsOfType(insts, child, op); } + } - // Process all matrix address uses. - lowerMatrixAddresses(module, matrixAddrInsts); + void materializeStorageToLogicalCasts(IRInst* root) + { + // We will process all CastStorageToLogical insts first, before + // processing all CastStorageToLogicalDeref. + // This is because when we materialize a + // `store(CastStorageToLogical(addr), CastStorageToLogicalDeref(src))`, + // we can just fold out CastStorageToLogicalDeref and emit + // `store(addr, load(src))` instead. 
+ // If we materialized `CastStorageToLogicalDeref` first we will + // miss this opportunity and generate more bloated code. + // + List<IRCastStorageToLogicalBase*> castInsts; + collectInstsOfType(castInsts, root, kIROp_CastStorageToLogical); + for (auto inst : castInsts) + materializeStorageToLogicalCastsImpl(inst); + + castInsts.clear(); + collectInstsOfType(castInsts, root, kIROp_CastStorageToLogicalDeref); + for (auto inst : castInsts) + materializeStorageToLogicalCastsImpl(inst); } // Lower all getElementPtr insts of a lowered matrix out of existance. - void lowerMatrixAddresses(IRModule* module, List<MatrixAddrWorkItem>& matrixAddrInsts) + void lowerMatrixAddresses(IRModule* module, MatrixAddrWorkItem workItem) { IRBuilder builder(module); - for (auto workItem : matrixAddrInsts) - { - auto majorAddr = workItem.matrixAddrInst; - auto majorGEP = as<IRGetElementPtr>(majorAddr); - SLANG_ASSERT(majorGEP); - auto loweredMatrixType = - cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); - auto matrixTypeInfo = getTypeLoweringMap(workItem.config) - .mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); - SLANG_ASSERT(matrixTypeInfo); - auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); - auto rowCount = getIntVal(matrixType->getRowCount()); - traverseUses( - majorAddr, - [&](IRUse* use) + auto majorAddr = workItem.matrixAddrInst; + auto majorGEP = as<IRGetElementPtr>(majorAddr); + SLANG_ASSERT(majorGEP); + auto baseCast = as<IRCastStorageToLogical>(majorGEP->getBase()); + SLANG_ASSERT(baseCast); + auto storageBase = baseCast->getOperand(0); + auto loweredMatrixType = cast<IRPtrTypeBase>(storageBase->getFullType())->getValueType(); + auto matrixTypeInfo = + getTypeLoweringMap(workItem.config).mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); + SLANG_ASSERT(matrixTypeInfo); + if (matrixTypeInfo->loweredType == matrixTypeInfo->originalType) + return; + auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); + auto 
colCount = getIntVal(matrixType->getColumnCount()); + traverseUses( + majorAddr, + [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + switch (user->getOp()) { - auto user = use->getUser(); - builder.setInsertBefore(user); - switch (user->getOp()) + case kIROp_Load: { - case kIROp_Load: + IRInst* resultInst = nullptr; + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + storageBase, + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - IRInst* resultInst = nullptr; - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType( - majorAddr->getDataType(), - matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - if (getIntVal(matrixType->getLayout()) == - SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - List<IRInst*> args; - for (IRIntegerValue i = 0; i < rowCount; i++) - { - auto vector = - builder.emitLoad(builder.emitElementAddress(dataPtr, i)); - auto element = - builder.emitElementExtract(vector, majorGEP->getIndex()); - args.add(element); - } - resultInst = builder.emitMakeVector( - builder.getVectorType( - matrixType->getElementType(), - (IRIntegerValue)args.getCount()), - args); - } - else + List<IRInst*> args; + for (IRIntegerValue i = 0; i < colCount; i++) { + auto vector = + builder.emitLoad(builder.emitElementAddress(dataPtr, i)); auto element = - builder.emitElementAddress(dataPtr, majorGEP->getIndex()); - resultInst = builder.emitLoad(element); + builder.emitElementExtract(vector, majorGEP->getIndex()); + args.add(element); } - user->replaceUsesWith(resultInst); - user->removeAndDeallocate(); + resultInst = builder.emitMakeVector( + builder.getVectorType( + matrixType->getElementType(), + (IRIntegerValue)args.getCount()), + args); } - break; - case kIROp_Store: + else { - auto storeInst = cast<IRStore>(user); - if 
(storeInst->getOperand(0) != majorAddr) - break; - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType( - majorAddr->getDataType(), - matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - if (getIntVal(matrixType->getLayout()) == - SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue i = 0; i < rowCount; i++) - { - auto vectorAddr = builder.emitElementAddress(dataPtr, i); - auto elementAddr = builder.emitElementAddress( - vectorAddr, - majorGEP->getIndex()); - builder.emitStore( - elementAddr, - builder.emitElementExtract(storeInst->getVal(), i)); - } - } - else - { - auto rowAddr = - builder.emitElementAddress(dataPtr, majorGEP->getIndex()); - builder.emitStore(rowAddr, storeInst->getVal()); - user->removeAndDeallocate(); - } - break; + auto element = + builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + resultInst = builder.emitLoad(element); } - case kIROp_GetElementPtr: + user->replaceUsesWith(resultInst); + user->removeAndDeallocate(); + } + break; + case kIROp_Store: + { + auto storeInst = cast<IRStore>(user); + if (storeInst->getOperand(0) != majorAddr) + break; + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + storageBase, + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - auto gep2 = cast<IRGetElementPtr>(user); - auto rowIndex = majorGEP->getIndex(); - auto colIndex = gep2->getIndex(); - if (getIntVal(matrixType->getLayout()) == - SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + for (IRIntegerValue i = 0; i < colCount; i++) { - Swap(rowIndex, colIndex); + auto vectorAddr = builder.emitElementAddress(dataPtr, i); + auto elementAddr = + builder.emitElementAddress(vectorAddr, majorGEP->getIndex()); + builder.emitStore( + elementAddr, + builder.emitElementExtract(storeInst->getVal(), i)); } - auto dataPtr = 
builder.emitFieldAddress( - getLoweredPtrLikeType( - majorAddr->getDataType(), - matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); - auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); - gep2->replaceUsesWith(elementAddr); - gep2->removeAndDeallocate(); - break; } - default: - SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs " - "storage lowering."); + else + { + auto rowAddr = + builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + builder.emitStore(rowAddr, storeInst->getVal()); + user->removeAndDeallocate(); + } break; } - }); - } + case kIROp_GetElementPtr: + { + auto gep2 = cast<IRGetElementPtr>(user); + auto rowIndex = majorGEP->getIndex(); + auto colIndex = gep2->getIndex(); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + Swap(rowIndex, colIndex); + } + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + storageBase, + matrixTypeInfo->loweredInnerStructKey); + auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); + auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); + gep2->replaceUsesWith(elementAddr); + gep2->removeAndDeallocate(); + break; + } + default: + SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs " + "storage lowering."); + break; + } + }); + if (!majorAddr->hasUses()) + majorAddr->removeAndDeallocate(); } }; |
