summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-any-value-marshalling.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-12-11 03:18:20 -0800
committerGitHub <noreply@github.com>2024-12-11 11:18:20 +0000
commitf573c15868234ab6013faf0fc2b93a72fa89f21d (patch)
tree500d91ceb22c61af2b4231c5f435c6722c346977 /source/slang/slang-ir-any-value-marshalling.cpp
parentf68768887b02df5080d35f0f32b035ef67764cd0 (diff)
Fix anyvalue marshalling for matrix and 64 bit types. (#5827)
* Fix anyvalue marshalling for matrix types. * Add support for 64bit types marshalling. --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-any-value-marshalling.cpp')
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp110
1 files changed, 96 insertions, 14 deletions
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index 163c5a808..6dc01d495 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -103,6 +103,12 @@ struct AnyValueMarshallingContext
intraFieldOffset = 0;
}
}
+ void ensureOffsetAt8ByteBoundary()
+ {
+ ensureOffsetAt4ByteBoundary();
+ if ((fieldOffset & 1) != 0)
+ fieldOffset++;
+ }
void ensureOffsetAt2ByteBoundary()
{
if (intraFieldOffset == 0)
@@ -146,6 +152,7 @@ struct AnyValueMarshallingContext
case kIROp_BoolType:
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
+ case kIROp_PtrType:
context->marshalBasicType(builder, dataType, concreteTypedVar);
break;
case kIROp_VectorType:
@@ -166,17 +173,36 @@ struct AnyValueMarshallingContext
auto matrixType = static_cast<IRMatrixType*>(dataType);
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
- for (IRIntegerValue i = 0; i < colCount; i++)
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
{
- auto col = builder->emitElementAddress(
- concreteTypedVar,
- builder->getIntValue(builder->getIntType(), i));
- for (IRIntegerValue j = 0; j < rowCount; j++)
+ for (IRIntegerValue i = 0; i < colCount; i++)
+ {
+ for (IRIntegerValue j = 0; j < rowCount; j++)
+ {
+ auto row = builder->emitElementAddress(
+ concreteTypedVar,
+ builder->getIntValue(builder->getIntType(), j));
+ auto element = builder->emitElementAddress(
+ row,
+ builder->getIntValue(builder->getIntType(), i));
+ emitMarshallingCode(builder, context, element);
+ }
+ }
+ }
+ else
+ {
+ for (IRIntegerValue i = 0; i < rowCount; i++)
{
- auto element = builder->emitElementAddress(
- col,
- builder->getIntValue(builder->getIntType(), j));
- emitMarshallingCode(builder, context, element);
+ auto row = builder->emitElementAddress(
+ concreteTypedVar,
+ builder->getIntValue(builder->getIntType(), i));
+ for (IRIntegerValue j = 0; j < colCount; j++)
+ {
+ auto element = builder->emitElementAddress(
+ row,
+ builder->getIntValue(builder->getIntType(), j));
+ emitMarshallingCode(builder, context, element);
+ }
}
}
break;
@@ -348,11 +374,39 @@ struct AnyValueMarshallingContext
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
+ case kIROp_PtrType:
#if SLANG_PTR_IS_64
case kIROp_UIntPtrType:
case kIROp_IntPtrType:
#endif
- SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
+ ensureOffsetAt8ByteBoundary();
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcVal = builder->emitLoad(concreteVar);
+ auto dstVal = builder->emitBitCast(builder->getUInt64Type(), srcVal);
+ auto lowBits = builder->emitCast(builder->getUIntType(), dstVal);
+ auto highBits = builder->emitShr(
+ builder->getUInt64Type(),
+ dstVal,
+ builder->getIntValue(builder->getIntType(), 32));
+ highBits = builder->emitCast(builder->getUIntType(), highBits);
+
+ auto dstAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ builder->emitStore(dstAddr, lowBits);
+ fieldOffset++;
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ dstAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ builder->emitStore(dstAddr, lowBits);
+ fieldOffset++;
+ }
+ }
break;
default:
SLANG_UNREACHABLE("unknown basic type");
@@ -545,7 +599,34 @@ struct AnyValueMarshallingContext
case kIROp_DoubleType:
case kIROp_Int8Type:
case kIROp_UInt8Type:
- SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
+ case kIROp_PtrType:
+#if SLANG_PTR_IS_64
+ case kIROp_IntPtrType:
+ case kIROp_UIntPtrType:
+#endif
+ ensureOffsetAt8ByteBoundary();
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ auto lowBits = builder->emitLoad(srcAddr);
+ fieldOffset++;
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcAddr1 = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ fieldOffset++;
+ auto highBits = builder->emitLoad(srcAddr1);
+ auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
+ if (dataType->getOp() != kIROp_UInt64Type)
+ combinedBits = builder->emitBitCast(dataType, combinedBits);
+ builder->emitStore(concreteVar, combinedBits);
+ }
+ }
break;
default:
SLANG_UNREACHABLE("unknown basic type");
@@ -735,7 +816,8 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
- return -1;
+ case kIROp_PtrType:
+ return alignUp(offset, 8) + 8;
case kIROp_Int16Type:
case kIROp_UInt16Type:
case kIROp_HalfType:
@@ -762,9 +844,9 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
auto elementType = matrixType->getElementType();
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
- for (IRIntegerValue i = 0; i < colCount; i++)
+ for (IRIntegerValue i = 0; i < rowCount; i++)
{
- for (IRIntegerValue j = 0; j < rowCount; j++)
+ for (IRIntegerValue j = 0; j < colCount; j++)
{
offset = _getAnyValueSizeRaw(elementType, offset);
if (offset < 0)