diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-11 03:18:20 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-11 11:18:20 +0000 |
| commit | f573c15868234ab6013faf0fc2b93a72fa89f21d (patch) | |
| tree | 500d91ceb22c61af2b4231c5f435c6722c346977 /source | |
| parent | f68768887b02df5080d35f0f32b035ef67764cd0 (diff) | |
Fix anyvalue marshalling for matrix and 64 bit types. (#5827)
* Fix anyvalue marshalling for matrix types.
* Add support for 64bit types marshalling.
---------
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-any-value-marshalling.cpp | 110 |
2 files changed, 112 insertions, 14 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index f4aa900db..8759ea9d4 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -3285,6 +3285,19 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return nullptr; } + SpvInst* emitMakeUInt64(SpvInstParent* parent, IRInst* inst) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto vec = emitOpCompositeConstruct( + parent, + nullptr, + builder.getVectorType(builder.getUIntType(), 2), + inst->getOperand(0), + inst->getOperand(1)); + return emitOpBitcast(parent, inst, inst->getDataType(), vec); + } + // The instructions that appear inside the basic blocks of // functions are what we will call "local" instructions. // @@ -3391,6 +3404,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_BitCast: result = emitOpBitcast(parent, inst, inst->getDataType(), inst->getOperand(0)); break; + case kIROp_MakeUInt64: + result = emitMakeUInt64(parent, inst); + break; case kIROp_Add: case kIROp_Sub: case kIROp_Mul: diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 163c5a808..6dc01d495 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -103,6 +103,12 @@ struct AnyValueMarshallingContext intraFieldOffset = 0; } } + void ensureOffsetAt8ByteBoundary() + { + ensureOffsetAt4ByteBoundary(); + if ((fieldOffset & 1) != 0) + fieldOffset++; + } void ensureOffsetAt2ByteBoundary() { if (intraFieldOffset == 0) @@ -146,6 +152,7 @@ struct AnyValueMarshallingContext case kIROp_BoolType: case kIROp_IntPtrType: case kIROp_UIntPtrType: + case kIROp_PtrType: context->marshalBasicType(builder, dataType, concreteTypedVar); break; case kIROp_VectorType: @@ -166,17 +173,36 @@ struct AnyValueMarshallingContext auto matrixType = static_cast<IRMatrixType*>(dataType); auto colCount = getIntVal(matrixType->getColumnCount()); auto rowCount = getIntVal(matrixType->getRowCount()); - for (IRIntegerValue i = 0; i < colCount; i++) + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - auto col = builder->emitElementAddress( - concreteTypedVar, - builder->getIntValue(builder->getIntType(), i)); - for (IRIntegerValue j = 0; j < rowCount; j++) + for (IRIntegerValue i = 0; i < colCount; i++) + { + for (IRIntegerValue j = 0; j < rowCount; j++) + { + auto row = builder->emitElementAddress( + concreteTypedVar, + builder->getIntValue(builder->getIntType(), j)); + auto element = builder->emitElementAddress( + row, + builder->getIntValue(builder->getIntType(), i)); + emitMarshallingCode(builder, context, element); + } + } + } + else + { + for (IRIntegerValue i = 0; i < rowCount; i++) { - auto element = builder->emitElementAddress( - col, - builder->getIntValue(builder->getIntType(), j)); - emitMarshallingCode(builder, context, element); + auto row = builder->emitElementAddress( + concreteTypedVar, + builder->getIntValue(builder->getIntType(), i)); + for (IRIntegerValue j = 0; j < colCount; j++) + { + auto element = builder->emitElementAddress( + row, + builder->getIntValue(builder->getIntType(), j)); + emitMarshallingCode(builder, context, element); + } } } break; @@ -348,11 +374,39 @@ struct AnyValueMarshallingContext case kIROp_UInt64Type: case kIROp_Int64Type: case kIROp_DoubleType: + case kIROp_PtrType: #if SLANG_PTR_IS_64 case kIROp_UIntPtrType: case kIROp_IntPtrType: #endif - SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); + ensureOffsetAt8ByteBoundary(); + if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) + { + auto srcVal = builder->emitLoad(concreteVar); + auto dstVal = builder->emitBitCast(builder->getUInt64Type(), srcVal); + auto lowBits = builder->emitCast(builder->getUIntType(), dstVal); + auto highBits = builder->emitShr( + builder->getUInt64Type(), + dstVal, + builder->getIntValue(builder->getIntType(), 32)); + highBits = builder->emitCast(builder->getUIntType(), highBits); + + auto dstAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + builder->emitStore(dstAddr, lowBits); + fieldOffset++; + if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) + { + dstAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + builder->emitStore(dstAddr, lowBits); + fieldOffset++; + } + } break; default: SLANG_UNREACHABLE("unknown basic type"); @@ -545,7 +599,34 @@ struct AnyValueMarshallingContext case kIROp_DoubleType: case kIROp_Int8Type: case kIROp_UInt8Type: - SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); + case kIROp_PtrType: +#if SLANG_PTR_IS_64 + case kIROp_IntPtrType: + case kIROp_UIntPtrType: +#endif + ensureOffsetAt8ByteBoundary(); + if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) + { + auto srcAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + auto lowBits = builder->emitLoad(srcAddr); + fieldOffset++; + if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) + { + auto srcAddr1 = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + fieldOffset++; + auto highBits = builder->emitLoad(srcAddr1); + auto combinedBits = builder->emitMakeUInt64(lowBits, highBits); + if (dataType->getOp() != kIROp_UInt64Type) + combinedBits = builder->emitBitCast(dataType, combinedBits); + builder->emitStore(concreteVar, combinedBits); + } + } break; default: SLANG_UNREACHABLE("unknown basic type"); @@ -735,7 +816,8 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) case kIROp_UInt64Type: case kIROp_Int64Type: case kIROp_DoubleType: - return -1; + case kIROp_PtrType: + return alignUp(offset, 8) + 8; case kIROp_Int16Type: case kIROp_UInt16Type: case kIROp_HalfType: @@ -762,9 +844,9 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) auto elementType = matrixType->getElementType(); auto colCount = getIntVal(matrixType->getColumnCount()); auto rowCount = getIntVal(matrixType->getRowCount()); - for (IRIntegerValue i = 0; i < colCount; i++) + for (IRIntegerValue i = 0; i < rowCount; i++) { - for (IRIntegerValue j = 0; j < rowCount; j++) + for (IRIntegerValue j = 0; j < colCount; j++) { offset = _getAnyValueSizeRaw(elementType, offset); if (offset < 0) |
