summary refs log tree commit diff stats
diff options
context:
space:
mode:
author: Yong He <yonghe@outlook.com> 2024-12-11 03:18:20 -0800
committer: GitHub <noreply@github.com> 2024-12-11 11:18:20 +0000
commit: f573c15868234ab6013faf0fc2b93a72fa89f21d (patch)
tree: 500d91ceb22c61af2b4231c5f435c6722c346977
parent: f68768887b02df5080d35f0f32b035ef67764cd0 (diff)
Fix anyvalue marshalling for matrix and 64 bit types. (#5827)
* Fix anyvalue marshalling for matrix types.
* Add support for 64bit types marshalling.
---------
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
-rw-r--r-- source/slang/slang-emit-spirv.cpp 16
-rw-r--r-- source/slang/slang-ir-any-value-marshalling.cpp 110
-rw-r--r-- tests/language-feature/anyvalue-layout.slang 44
-rw-r--r-- tests/language-feature/anyvalue-matrix-layout.slang 30
4 files changed, 186 insertions, 14 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index f4aa900db..8759ea9d4 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -3285,6 +3285,19 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return nullptr;
}
+ SpvInst* emitMakeUInt64(SpvInstParent* parent, IRInst* inst) // Emits SPIR-V for kIROp_MakeUInt64: build a 2-component uint vector from the two operands, then bitcast it to inst's data type.
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto vec = emitOpCompositeConstruct( // vector of operand 0 and operand 1 (the unmarshalling code passes the low word first, then the high word)
+ parent,
+ nullptr,
+ builder.getVectorType(builder.getUIntType(), 2),
+ inst->getOperand(0),
+ inst->getOperand(1));
+ return emitOpBitcast(parent, inst, inst->getDataType(), vec); // the bitcast result is what gets returned for inst
+ }
+
// The instructions that appear inside the basic blocks of
// functions are what we will call "local" instructions.
//
@@ -3391,6 +3404,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_BitCast:
result = emitOpBitcast(parent, inst, inst->getDataType(), inst->getOperand(0));
break;
+ case kIROp_MakeUInt64:
+ result = emitMakeUInt64(parent, inst);
+ break;
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index 163c5a808..6dc01d495 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -103,6 +103,12 @@ struct AnyValueMarshallingContext
intraFieldOffset = 0;
}
}
+ void ensureOffsetAt8ByteBoundary() // Aligns to 4 bytes first, then rounds fieldOffset up to an even index (presumably each field is one 32-bit word, so an even index is 8-byte aligned — confirm).
+ {
+ ensureOffsetAt4ByteBoundary();
+ if ((fieldOffset & 1) != 0) // odd field index: skip one field
+ fieldOffset++;
+ }
void ensureOffsetAt2ByteBoundary()
{
if (intraFieldOffset == 0)
@@ -146,6 +152,7 @@ struct AnyValueMarshallingContext
case kIROp_BoolType:
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
+ case kIROp_PtrType:
context->marshalBasicType(builder, dataType, concreteTypedVar);
break;
case kIROp_VectorType:
@@ -166,17 +173,36 @@ struct AnyValueMarshallingContext
auto matrixType = static_cast<IRMatrixType*>(dataType);
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
- for (IRIntegerValue i = 0; i < colCount; i++)
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
{
- auto col = builder->emitElementAddress(
- concreteTypedVar,
- builder->getIntValue(builder->getIntType(), i));
- for (IRIntegerValue j = 0; j < rowCount; j++)
+ for (IRIntegerValue i = 0; i < colCount; i++)
+ {
+ for (IRIntegerValue j = 0; j < rowCount; j++)
+ {
+ auto row = builder->emitElementAddress(
+ concreteTypedVar,
+ builder->getIntValue(builder->getIntType(), j));
+ auto element = builder->emitElementAddress(
+ row,
+ builder->getIntValue(builder->getIntType(), i));
+ emitMarshallingCode(builder, context, element);
+ }
+ }
+ }
+ else
+ {
+ for (IRIntegerValue i = 0; i < rowCount; i++)
{
- auto element = builder->emitElementAddress(
- col,
- builder->getIntValue(builder->getIntType(), j));
- emitMarshallingCode(builder, context, element);
+ auto row = builder->emitElementAddress(
+ concreteTypedVar,
+ builder->getIntValue(builder->getIntType(), i));
+ for (IRIntegerValue j = 0; j < colCount; j++)
+ {
+ auto element = builder->emitElementAddress(
+ row,
+ builder->getIntValue(builder->getIntType(), j));
+ emitMarshallingCode(builder, context, element);
+ }
}
}
break;
@@ -348,11 +374,39 @@ struct AnyValueMarshallingContext
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
+ case kIROp_PtrType:
#if SLANG_PTR_IS_64
case kIROp_UIntPtrType:
case kIROp_IntPtrType:
#endif
- SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
+ ensureOffsetAt8ByteBoundary(); // a 64-bit value is written as two consecutive uint fields starting at an even field index
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcVal = builder->emitLoad(concreteVar);
+ auto dstVal = builder->emitBitCast(builder->getUInt64Type(), srcVal); // view the value's bits as a uint64
+ auto lowBits = builder->emitCast(builder->getUIntType(), dstVal); // low word: the uint64 cast down to uint
+ auto highBits = builder->emitShr(
+ builder->getUInt64Type(),
+ dstVal,
+ builder->getIntValue(builder->getIntType(), 32));
+ highBits = builder->emitCast(builder->getUIntType(), highBits); // high word: the value shifted right by 32, cast to uint
+
+ auto dstAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ builder->emitStore(dstAddr, lowBits);
+ fieldOffset++;
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) // the high word is dropped when no field is left for it
+ {
+ dstAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ builder->emitStore(dstAddr, highBits); // second field must hold the high word; unmarshalling reads it back as the high operand of MakeUInt64
+ fieldOffset++;
+ }
+ }
break;
default:
SLANG_UNREACHABLE("unknown basic type");
@@ -545,7 +599,34 @@ struct AnyValueMarshallingContext
case kIROp_DoubleType:
case kIROp_Int8Type:
case kIROp_UInt8Type:
- SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
+ case kIROp_PtrType:
+#if SLANG_PTR_IS_64
+ case kIROp_IntPtrType:
+ case kIROp_UIntPtrType:
+#endif
+ ensureOffsetAt8ByteBoundary(); // NOTE(review): the Int8Type/UInt8Type labels above now also fall through into this 64-bit path — confirm that is intended
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ auto lowBits = builder->emitLoad(srcAddr); // first field: low word
+ fieldOffset++;
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) // nothing is stored to concreteVar unless a second field is available
+ {
+ auto srcAddr1 = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ fieldOffset++;
+ auto highBits = builder->emitLoad(srcAddr1); // second field: high word
+ auto combinedBits = builder->emitMakeUInt64(lowBits, highBits); // low word first, then high word
+ if (dataType->getOp() != kIROp_UInt64Type) // any other destination type reinterprets the combined uint64 bits
+ combinedBits = builder->emitBitCast(dataType, combinedBits);
+ builder->emitStore(concreteVar, combinedBits);
+ }
+ }
break;
default:
SLANG_UNREACHABLE("unknown basic type");
@@ -735,7 +816,8 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
- return -1;
+ case kIROp_PtrType:
+ return alignUp(offset, 8) + 8;
case kIROp_Int16Type:
case kIROp_UInt16Type:
case kIROp_HalfType:
@@ -762,9 +844,9 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
auto elementType = matrixType->getElementType();
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
- for (IRIntegerValue i = 0; i < colCount; i++)
+ for (IRIntegerValue i = 0; i < rowCount; i++)
{
- for (IRIntegerValue j = 0; j < rowCount; j++)
+ for (IRIntegerValue j = 0; j < colCount; j++)
{
offset = _getAnyValueSizeRaw(elementType, offset);
if (offset < 0)
diff --git a/tests/language-feature/anyvalue-layout.slang b/tests/language-feature/anyvalue-layout.slang
new file mode 100644
index 000000000..15e58f093
--- /dev/null
+++ b/tests/language-feature/anyvalue-layout.slang
@@ -0,0 +1,44 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -dx12 -use-dxil -profile cs_6_1 -output-using-type
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -output-using-type
+
+interface IFoo
+{
+ float getVal();
+ uint64_t getPtrVal();
+}
+
+struct Foo : IFoo
+{
+ column_major float3x2 m; // matrix member with an explicit column_major layout
+ int x;
+ uint64_t ptr; // 64-bit member; computeMain checks its low and high 32-bit halves separately
+ float getVal()
+ {
+ return m[2][0]; // the CHECK in computeMain expects 3.0 for this element
+ }
+ uint64_t getPtrVal()
+ {
+ return ptr;
+ }
+}
+
+//TEST_INPUT: type_conformance Foo:IFoo = 0
+
+//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0 0 0 1 2], stride=4)
+RWStructuredBuffer<IFoo> gFoo;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK: 3.0
+ outputBuffer[0] = gFoo[0].getVal();
+
+ // CHECK: 1.0
+ outputBuffer[1] = gFoo[0].getPtrVal()&0xFFFFFFFF;
+
+ // CHECK: 2.0
+ outputBuffer[2] = gFoo[0].getPtrVal()>>32;
+} \ No newline at end of file
diff --git a/tests/language-feature/anyvalue-matrix-layout.slang b/tests/language-feature/anyvalue-matrix-layout.slang
new file mode 100644
index 000000000..351eec81b
--- /dev/null
+++ b/tests/language-feature/anyvalue-matrix-layout.slang
@@ -0,0 +1,30 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type
+
+interface IFoo
+{
+ float getVal();
+}
+
+struct Foo : IFoo
+{
+ column_major float3x2 m; // matrix member with an explicit column_major layout
+ float getVal()
+ {
+ return m[2][0]; // the CHECK in computeMain expects 3.0 for this element
+ }
+}
+
+//TEST_INPUT: type_conformance Foo:IFoo = 0
+
+//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0], stride=4)
+RWStructuredBuffer<IFoo> gFoo;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK: 3.0
+ outputBuffer[0] = gFoo[0].getVal();
+} \ No newline at end of file