summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-any-value-marshalling.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-any-value-marshalling.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-any-value-marshalling.cpp')
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp938
1 files changed, 492 insertions, 446 deletions
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index f09294aa9..b51ab4608 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -2,147 +2,153 @@
#include "../core/slang-math.h"
#include "slang-ir-generics-lowering-context.h"
-#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir.h"
#include "slang-legalize-types.h"
namespace Slang
{
- // This is a subpass of generics lowering IR transformation.
- // This pass generates packing/unpacking functions for `AnyValue`s,
- // and replaces all `IRPackAnyValue` and `IRUnpackAnyValue` with calls to these
- // functions.
- struct AnyValueMarshallingContext
- {
- SharedGenericsLoweringContext* sharedContext;
+// This is a subpass of generics lowering IR transformation.
+// This pass generates packing/unpacking functions for `AnyValue`s,
+// and replaces all `IRPackAnyValue` and `IRUnpackAnyValue` with calls to these
+// functions.
+struct AnyValueMarshallingContext
+{
+ SharedGenericsLoweringContext* sharedContext;
- // Stores information about generated `AnyValue` struct types.
- struct AnyValueTypeInfo : RefObject
- {
- IRType* type; // The generated IR value for the `AnyValue<N>` struct type.
- List<IRStructKey*> fieldKeys; // `IRStructKey`s for the fields of the generated type.
- };
+ // Stores information about generated `AnyValue` struct types.
+ struct AnyValueTypeInfo : RefObject
+ {
+ IRType* type; // The generated IR value for the `AnyValue<N>` struct type.
+ List<IRStructKey*> fieldKeys; // `IRStructKey`s for the fields of the generated type.
+ };
- Dictionary<IRIntegerValue, RefPtr<AnyValueTypeInfo>> generatedAnyValueTypes;
+ Dictionary<IRIntegerValue, RefPtr<AnyValueTypeInfo>> generatedAnyValueTypes;
- struct MarshallingFunctionKey
+ struct MarshallingFunctionKey
+ {
+ IRType* originalType;
+ IRIntegerValue anyValueSize;
+ bool operator==(MarshallingFunctionKey other) const
{
- IRType* originalType;
- IRIntegerValue anyValueSize;
- bool operator ==(MarshallingFunctionKey other) const
- {
- return originalType == other.originalType && anyValueSize == other.anyValueSize;
- }
- HashCode getHashCode() const
- {
- return combineHash(Slang::getHashCode(originalType), Slang::getHashCode(anyValueSize));
- }
- };
-
- struct MarshallingFunctionSet
+ return originalType == other.originalType && anyValueSize == other.anyValueSize;
+ }
+ HashCode getHashCode() const
{
- IRFunc* packFunc;
- IRFunc* unpackFunc;
- };
+ return combineHash(Slang::getHashCode(originalType), Slang::getHashCode(anyValueSize));
+ }
+ };
- // Stores the generated packing/unpacking functions for lookup.
- Dictionary<MarshallingFunctionKey, MarshallingFunctionSet> mapTypeMarshalingFunctions;
+ struct MarshallingFunctionSet
+ {
+ IRFunc* packFunc;
+ IRFunc* unpackFunc;
+ };
- AnyValueTypeInfo* ensureAnyValueType(IRAnyValueType* type)
+ // Stores the generated packing/unpacking functions for lookup.
+ Dictionary<MarshallingFunctionKey, MarshallingFunctionSet> mapTypeMarshalingFunctions;
+
+ AnyValueTypeInfo* ensureAnyValueType(IRAnyValueType* type)
+ {
+ auto size = getIntVal(type->getSize());
+ if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size))
+ return typeInfo->Ptr();
+ RefPtr<AnyValueTypeInfo> info = new AnyValueTypeInfo();
+ IRBuilder builder(sharedContext->module);
+ builder.setInsertBefore(type);
+ auto structType = builder.createStructType();
+ info->type = structType;
+ StringBuilder nameSb;
+ nameSb << "AnyValue" << size;
+ builder.addExportDecoration(structType, nameSb.getUnownedSlice());
+ auto fieldCount = (size + sizeof(uint32_t) - 1) / sizeof(uint32_t);
+ for (decltype(fieldCount) i = 0; i < fieldCount; i++)
{
- auto size = getIntVal(type->getSize());
- if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size))
- return typeInfo->Ptr();
- RefPtr<AnyValueTypeInfo> info = new AnyValueTypeInfo();
- IRBuilder builder(sharedContext->module);
- builder.setInsertBefore(type);
- auto structType = builder.createStructType();
- info->type = structType;
- StringBuilder nameSb;
- nameSb << "AnyValue" << size;
- builder.addExportDecoration(structType, nameSb.getUnownedSlice());
- auto fieldCount = (size + sizeof(uint32_t) - 1) / sizeof(uint32_t);
- for (decltype(fieldCount) i = 0; i < fieldCount; i++)
- {
- auto key = builder.createStructKey();
- nameSb.clear();
- nameSb << "field" << i;
- builder.addNameHintDecoration(key, nameSb.getUnownedSlice());
- nameSb << "_anyVal" << size;
- builder.addExportDecoration(key, nameSb.getUnownedSlice());
- builder.createStructField(structType, key, builder.getUIntType());
- info->fieldKeys.add(key);
- }
- generatedAnyValueTypes[size] = info;
- return info.Ptr();
+ auto key = builder.createStructKey();
+ nameSb.clear();
+ nameSb << "field" << i;
+ builder.addNameHintDecoration(key, nameSb.getUnownedSlice());
+ nameSb << "_anyVal" << size;
+ builder.addExportDecoration(key, nameSb.getUnownedSlice());
+ builder.createStructField(structType, key, builder.getUIntType());
+ info->fieldKeys.add(key);
}
+ generatedAnyValueTypes[size] = info;
+ return info.Ptr();
+ }
- struct TypeMarshallingContext
+ struct TypeMarshallingContext
+ {
+ AnyValueTypeInfo* anyValInfo;
+ uint32_t fieldOffset;
+ uint32_t intraFieldOffset;
+ IRType* uintPtrType;
+ IRInst* anyValueVar;
+ // Defines what to do with basic typed data elements.
+ virtual void marshalBasicType(
+ IRBuilder* builder,
+ IRType* dataType,
+ IRInst* concreteTypedVar) = 0;
+ // Defines what to do with resource handle elements.
+ virtual void marshalResourceHandle(
+ IRBuilder* builder,
+ IRType* dataType,
+ IRInst* concreteTypedVar) = 0;
+
+ void ensureOffsetAt4ByteBoundary()
{
- AnyValueTypeInfo* anyValInfo;
- uint32_t fieldOffset;
- uint32_t intraFieldOffset;
- IRType* uintPtrType;
- IRInst* anyValueVar;
- // Defines what to do with basic typed data elements.
- virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;
- // Defines what to do with resource handle elements.
- virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;
-
- void ensureOffsetAt4ByteBoundary()
- {
- if (intraFieldOffset)
- {
- fieldOffset++;
- intraFieldOffset = 0;
- }
- }
- void ensureOffsetAt2ByteBoundary()
+ if (intraFieldOffset)
{
- if (intraFieldOffset == 0)
- return;
- if (intraFieldOffset <= 2)
- {
- intraFieldOffset = 2;
- return;
- }
fieldOffset++;
intraFieldOffset = 0;
- return;
}
- void advanceOffset(uint32_t bytes)
+ }
+ void ensureOffsetAt2ByteBoundary()
+ {
+ if (intraFieldOffset == 0)
+ return;
+ if (intraFieldOffset <= 2)
{
- intraFieldOffset += bytes;
- fieldOffset += intraFieldOffset / 4;
- intraFieldOffset = intraFieldOffset % 4;
+ intraFieldOffset = 2;
+ return;
}
- };
+ fieldOffset++;
+ intraFieldOffset = 0;
+ return;
+ }
+ void advanceOffset(uint32_t bytes)
+ {
+ intraFieldOffset += bytes;
+ fieldOffset += intraFieldOffset / 4;
+ intraFieldOffset = intraFieldOffset % 4;
+ }
+ };
- void emitMarshallingCode(
- IRBuilder* builder,
- TypeMarshallingContext* context,
- IRInst* concreteTypedVar)
+ void emitMarshallingCode(
+ IRBuilder* builder,
+ TypeMarshallingContext* context,
+ IRInst* concreteTypedVar)
+ {
+ auto dataType = cast<IRPtrTypeBase>(concreteTypedVar->getDataType())->getValueType();
+ switch (dataType->getOp())
{
- auto dataType = cast<IRPtrTypeBase>(concreteTypedVar->getDataType())->getValueType();
- switch (dataType->getOp())
- {
- case kIROp_IntType:
- case kIROp_FloatType:
- case kIROp_UIntType:
- case kIROp_UInt64Type:
- case kIROp_Int64Type:
- case kIROp_DoubleType:
- case kIROp_Int8Type:
- case kIROp_Int16Type:
- case kIROp_UInt8Type:
- case kIROp_UInt16Type:
- case kIROp_HalfType:
- case kIROp_BoolType:
- case kIROp_IntPtrType:
- case kIROp_UIntPtrType:
- context->marshalBasicType(builder, dataType, concreteTypedVar);
- break;
- case kIROp_VectorType:
+ case kIROp_IntType:
+ case kIROp_FloatType:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_Int64Type:
+ case kIROp_DoubleType:
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_HalfType:
+ case kIROp_BoolType:
+ case kIROp_IntPtrType:
+ case kIROp_UIntPtrType:
+ context->marshalBasicType(builder, dataType, concreteTypedVar);
+ break;
+ case kIROp_VectorType:
{
auto vectorType = static_cast<IRVectorType*>(dataType);
auto elementCount = getIntVal(vectorType->getElementCount());
@@ -155,7 +161,7 @@ namespace Slang
}
break;
}
- case kIROp_MatrixType:
+ case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(dataType);
auto colCount = getIntVal(matrixType->getColumnCount());
@@ -175,7 +181,7 @@ namespace Slang
}
break;
}
- case kIROp_StructType:
+ case kIROp_StructType:
{
auto structType = cast<IRStructType>(dataType);
for (auto field : structType->getFields())
@@ -188,7 +194,7 @@ namespace Slang
}
break;
}
- case kIROp_ArrayType:
+ case kIROp_ArrayType:
{
auto arrayType = cast<IRArrayType>(dataType);
for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++)
@@ -200,7 +206,7 @@ namespace Slang
}
break;
}
- case kIROp_AnyValueType:
+ case kIROp_AnyValueType:
{
auto anyValType = cast<IRAnyValueType>(dataType);
auto info = ensureAnyValueType(anyValType);
@@ -214,27 +220,28 @@ namespace Slang
}
break;
}
- default:
- if (isResourceType(dataType))
- {
- context->marshalResourceHandle(builder, dataType, concreteTypedVar);
- return;
- }
- SLANG_UNIMPLEMENTED_X("Unimplemented type packing");
- break;
+ default:
+ if (isResourceType(dataType))
+ {
+ context->marshalResourceHandle(builder, dataType, concreteTypedVar);
+ return;
}
+ SLANG_UNIMPLEMENTED_X("Unimplemented type packing");
+ break;
}
+ }
- struct TypePackingContext : TypeMarshallingContext
+ struct TypePackingContext : TypeMarshallingContext
+ {
+ virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar)
+ override
{
- virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
+ switch (dataType->getOp())
{
- switch (dataType->getOp())
- {
- case kIROp_IntType:
- case kIROp_FloatType:
+ case kIROp_IntType:
+ case kIROp_FloatType:
#if SLANG_PTR_IS_32
- case kIROp_IntPtrType:
+ case kIROp_IntPtrType:
#endif
{
ensureOffsetAt4ByteBoundary();
@@ -251,14 +258,21 @@ namespace Slang
advanceOffset(4);
break;
}
- case kIROp_BoolType:
+ case kIROp_BoolType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
- IRInst* args[] = {srcVal, builder->getIntValue(builder->getUIntType(), 1), builder->getIntValue(builder->getUIntType(), 0) };
- auto dstVal = builder->emitIntrinsicInst(builder->getUIntType(), kIROp_Select, 3, args);
+ IRInst* args[] = {
+ srcVal,
+ builder->getIntValue(builder->getUIntType(), 1),
+ builder->getIntValue(builder->getUIntType(), 0)};
+ auto dstVal = builder->emitIntrinsicInst(
+ builder->getUIntType(),
+ kIROp_Select,
+ 3,
+ args);
auto dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
@@ -268,9 +282,9 @@ namespace Slang
advanceOffset(4);
break;
}
- case kIROp_UIntType:
+ case kIROp_UIntType:
#if SLANG_PTR_IS_32
- case kIROp_UIntPtrType:
+ case kIROp_UIntPtrType:
#endif
{
ensureOffsetAt4ByteBoundary();
@@ -286,9 +300,9 @@ namespace Slang
advanceOffset(4);
break;
}
- case kIROp_Int16Type:
- case kIROp_UInt16Type:
- case kIROp_HalfType:
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ case kIROp_HalfType:
{
ensureOffsetAt2ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
@@ -296,7 +310,8 @@ namespace Slang
auto srcVal = builder->emitLoad(concreteVar);
if (dataType->getOp() == kIROp_HalfType)
{
- srcVal = builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal);
+ srcVal =
+ builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal);
}
srcVal = builder->emitCast(builder->getType(kIROp_UIntType), srcVal);
auto dstAddr = builder->emitFieldAddress(
@@ -307,16 +322,19 @@ namespace Slang
if (intraFieldOffset == 0)
{
dstVal = builder->emitBitAnd(
- dstVal->getFullType(), dstVal,
+ dstVal->getFullType(),
+ dstVal,
builder->getIntValue(builder->getUIntType(), 0xFFFF0000));
}
else
{
srcVal = builder->emitShl(
- srcVal->getFullType(), srcVal,
+ srcVal->getFullType(),
+ srcVal,
builder->getIntValue(builder->getUIntType(), 16));
dstVal = builder->emitBitAnd(
- dstVal->getFullType(), dstVal,
+ dstVal->getFullType(),
+ dstVal,
builder->getIntValue(builder->getUIntType(), 0xFFFF));
}
dstVal = builder->emitBitOr(dstVal->getFullType(), dstVal, srcVal);
@@ -325,105 +343,111 @@ namespace Slang
advanceOffset(2);
break;
}
- case kIROp_Int8Type:
- case kIROp_UInt8Type:
- case kIROp_UInt64Type:
- case kIROp_Int64Type:
- case kIROp_DoubleType:
+ case kIROp_Int8Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt64Type:
+ case kIROp_Int64Type:
+ case kIROp_DoubleType:
#if SLANG_PTR_IS_64
- case kIROp_UIntPtrType:
- case kIROp_IntPtrType:
+ case kIROp_UIntPtrType:
+ case kIROp_IntPtrType:
#endif
- SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
- break;
- default:
- SLANG_UNREACHABLE("unknown basic type");
- }
+ SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
+ break;
+ default: SLANG_UNREACHABLE("unknown basic type");
}
+ }
- virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
+ virtual void marshalResourceHandle(
+ IRBuilder* builder,
+ IRType* dataType,
+ IRInst* concreteVar) override
+ {
+ SLANG_UNUSED(dataType);
+ ensureOffsetAt4ByteBoundary();
+ if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
- SLANG_UNUSED(dataType);
- ensureOffsetAt4ByteBoundary();
- if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
- {
- auto srcVal = builder->emitLoad(concreteVar);
- auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal);
- auto lowBits = builder->emitCast(builder->getUIntType(), uint64Val);
- auto shiftedBits = builder->emitShr(
- builder->getUInt64Type(),
- uint64Val,
- builder->getIntValue(builder->getIntType(), 32));
- auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits);
- auto dstAddr1 = builder->emitFieldAddress(
- uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
- builder->emitStore(dstAddr1, lowBits);
- auto dstAddr2 = builder->emitFieldAddress(
- uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
- builder->emitStore(dstAddr2, highBits);
- advanceOffset(8);
- }
+ auto srcVal = builder->emitLoad(concreteVar);
+ auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal);
+ auto lowBits = builder->emitCast(builder->getUIntType(), uint64Val);
+ auto shiftedBits = builder->emitShr(
+ builder->getUInt64Type(),
+ uint64Val,
+ builder->getIntValue(builder->getIntType(), 32));
+ auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits);
+ auto dstAddr1 = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ builder->emitStore(dstAddr1, lowBits);
+ auto dstAddr2 = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset + 1]);
+ builder->emitStore(dstAddr2, highBits);
+ advanceOffset(8);
}
- };
+ }
+ };
- IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType)
- {
- IRBuilder builder(sharedContext->module);
- builder.setInsertBefore(type);
- auto anyValInfo = ensureAnyValueType(anyValueType);
+ IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType)
+ {
+ IRBuilder builder(sharedContext->module);
+ builder.setInsertBefore(type);
+ auto anyValInfo = ensureAnyValueType(anyValueType);
- auto func = builder.createFunc();
+ auto func = builder.createFunc();
- StringBuilder nameSb;
- nameSb << "packAnyValue" << getIntVal(anyValueType->getSize());
- builder.addNameHintDecoration(func, nameSb.getUnownedSlice());
- // Currently we don't add linkage to the generated func, since we
- // do not have a way to compute mangled names from an IR entity.
- // This will leads to duplicate packing functions in linked code
- // but there won't be correctness issues.
+ StringBuilder nameSb;
+ nameSb << "packAnyValue" << getIntVal(anyValueType->getSize());
+ builder.addNameHintDecoration(func, nameSb.getUnownedSlice());
+ // Currently we don't add linkage to the generated func, since we
+ // do not have a way to compute mangled names from an IR entity.
+ // This will leads to duplicate packing functions in linked code
+ // but there won't be correctness issues.
- auto funcType = builder.getFuncType(1, &type, anyValInfo->type);
- func->setFullType(funcType);
- builder.setInsertInto(func);
+ auto funcType = builder.getFuncType(1, &type, anyValInfo->type);
+ func->setFullType(funcType);
+ builder.setInsertInto(func);
- builder.emitBlock();
+ builder.emitBlock();
- auto param = builder.emitParam(type);
- auto concreteTypedVar = builder.emitVar(type);
- builder.emitStore(concreteTypedVar, param);
- auto resultVar = builder.emitVar(anyValInfo->type);
+ auto param = builder.emitParam(type);
+ auto concreteTypedVar = builder.emitVar(type);
+ builder.emitStore(concreteTypedVar, param);
+ auto resultVar = builder.emitVar(anyValInfo->type);
- // Initialize fields to 0 to prevent downstream compiler error.
- for (uint32_t offset = 0; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++)
- {
- auto fieldAddr = builder.emitFieldAddress(
- builder.getPtrType(builder.getUIntType()),
- resultVar,
- anyValInfo->fieldKeys[offset]
- );
- builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0));
- }
+ // Initialize fields to 0 to prevent downstream compiler error.
+ for (uint32_t offset = 0; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++)
+ {
+ auto fieldAddr = builder.emitFieldAddress(
+ builder.getPtrType(builder.getUIntType()),
+ resultVar,
+ anyValInfo->fieldKeys[offset]);
+ builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0));
+ }
- TypePackingContext context;
- context.anyValInfo = anyValInfo;
- context.fieldOffset = context.intraFieldOffset = 0;
- context.uintPtrType = builder.getPtrType(builder.getUIntType());
- context.anyValueVar = resultVar;
- emitMarshallingCode(&builder, &context, concreteTypedVar);
+ TypePackingContext context;
+ context.anyValInfo = anyValInfo;
+ context.fieldOffset = context.intraFieldOffset = 0;
+ context.uintPtrType = builder.getPtrType(builder.getUIntType());
+ context.anyValueVar = resultVar;
+ emitMarshallingCode(&builder, &context, concreteTypedVar);
- auto load = builder.emitLoad(resultVar);
- builder.emitReturn(load);
- return func;
- }
+ auto load = builder.emitLoad(resultVar);
+ builder.emitReturn(load);
+ return func;
+ }
- struct TypeUnpackingContext : TypeMarshallingContext
+ struct TypeUnpackingContext : TypeMarshallingContext
+ {
+ virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar)
+ override
{
- virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
+ switch (dataType->getOp())
{
- switch (dataType->getOp())
- {
- case kIROp_IntType:
- case kIROp_FloatType:
+ case kIROp_IntType:
+ case kIROp_FloatType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
@@ -439,7 +463,7 @@ namespace Slang
advanceOffset(4);
break;
}
- case kIROp_BoolType:
+ case kIROp_BoolType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
@@ -449,13 +473,15 @@ namespace Slang
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
auto srcVal = builder->emitLoad(srcAddr);
- srcVal = builder->emitNeq(srcVal, builder->getIntValue(builder->getUIntType(), 0));
+ srcVal = builder->emitNeq(
+ srcVal,
+ builder->getIntValue(builder->getUIntType(), 0));
builder->emitStore(concreteVar, srcVal);
}
advanceOffset(4);
break;
}
- case kIROp_UIntType:
+ case kIROp_UIntType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
@@ -470,9 +496,9 @@ namespace Slang
advanceOffset(4);
break;
}
- case kIROp_Int16Type:
- case kIROp_UInt16Type:
- case kIROp_HalfType:
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ case kIROp_HalfType:
{
ensureOffsetAt2ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
@@ -485,13 +511,15 @@ namespace Slang
if (intraFieldOffset == 0)
{
srcVal = builder->emitBitAnd(
- srcVal->getFullType(), srcVal,
+ srcVal->getFullType(),
+ srcVal,
builder->getIntValue(builder->getUIntType(), 0xFFFF));
}
else
{
srcVal = builder->emitShr(
- srcVal->getFullType(), srcVal,
+ srcVal->getFullType(),
+ srcVal,
builder->getIntValue(builder->getUIntType(), 16));
}
if (dataType->getOp() == kIROp_Int16Type)
@@ -511,203 +539,205 @@ namespace Slang
advanceOffset(2);
break;
}
- case kIROp_UInt64Type:
- case kIROp_Int64Type:
- case kIROp_DoubleType:
- case kIROp_Int8Type:
- case kIROp_UInt8Type:
- SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
- break;
- default:
- SLANG_UNREACHABLE("unknown basic type");
- }
+ case kIROp_UInt64Type:
+ case kIROp_Int64Type:
+ case kIROp_DoubleType:
+ case kIROp_Int8Type:
+ case kIROp_UInt8Type:
+ SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
+ break;
+ default: SLANG_UNREACHABLE("unknown basic type");
}
+ }
- virtual void marshalResourceHandle(
- IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
+ virtual void marshalResourceHandle(
+ IRBuilder* builder,
+ IRType* dataType,
+ IRInst* concreteVar) override
+ {
+ ensureOffsetAt4ByteBoundary();
+ if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
- ensureOffsetAt4ByteBoundary();
- if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
- {
- auto srcAddr = builder->emitFieldAddress(
- uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
- auto lowBits = builder->emitLoad(srcAddr);
-
- auto srcAddr1 = builder->emitFieldAddress(
- uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
- auto highBits = builder->emitLoad(srcAddr1);
-
- auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
- combinedBits = builder->emitBitCast(dataType, combinedBits);
- builder->emitStore(concreteVar, combinedBits);
- advanceOffset(8);
- }
+ auto srcAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ auto lowBits = builder->emitLoad(srcAddr);
+
+ auto srcAddr1 = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset + 1]);
+ auto highBits = builder->emitLoad(srcAddr1);
+
+ auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
+ combinedBits = builder->emitBitCast(dataType, combinedBits);
+ builder->emitStore(concreteVar, combinedBits);
+ advanceOffset(8);
}
- };
-
- IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType)
- {
- IRBuilder builder(sharedContext->module);
- builder.setInsertBefore(type);
- auto anyValInfo = ensureAnyValueType(anyValueType);
-
- auto func = builder.createFunc();
-
- StringBuilder nameSb;
- nameSb << "unpackAnyValue" << getIntVal(anyValueType->getSize());
- builder.addNameHintDecoration(func, nameSb.getUnownedSlice());
-
- auto funcType = builder.getFuncType(1, &anyValInfo->type, type);
- func->setFullType(funcType);
- builder.setInsertInto(func);
-
- builder.emitBlock();
-
- auto param = builder.emitParam(anyValInfo->type);
- auto anyValueVar = builder.emitVar(anyValInfo->type);
- builder.emitStore(anyValueVar, param);
- auto resultVar = builder.emitVar(type);
-
- TypeUnpackingContext context;
- context.anyValInfo = anyValInfo;
- context.fieldOffset = context.intraFieldOffset = 0;
- context.uintPtrType = builder.getPtrType(builder.getUIntType());
- context.anyValueVar = anyValueVar;
- emitMarshallingCode(&builder, &context, resultVar);
- auto load = builder.emitLoad(resultVar);
- builder.emitReturn(load);
- return func;
}
+ };
- // Ensures the marshalling functions between `type` and `anyValueType` are already generated.
- // Returns the generated marshalling functions.
- MarshallingFunctionSet ensureMarshallingFunc(IRType* type, IRAnyValueType* anyValueType)
- {
- auto size = getIntVal(anyValueType->getSize());
- MarshallingFunctionKey key;
- key.originalType = type;
- key.anyValueSize = size;
- MarshallingFunctionSet funcSet;
- if (mapTypeMarshalingFunctions.tryGetValue(key, funcSet))
- return funcSet;
- funcSet.packFunc = generatePackingFunc(type, anyValueType);
- funcSet.unpackFunc = generateUnpackingFunc(type, anyValueType);
- mapTypeMarshalingFunctions[key] = funcSet;
+ IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType)
+ {
+ IRBuilder builder(sharedContext->module);
+ builder.setInsertBefore(type);
+ auto anyValInfo = ensureAnyValueType(anyValueType);
+
+ auto func = builder.createFunc();
+
+ StringBuilder nameSb;
+ nameSb << "unpackAnyValue" << getIntVal(anyValueType->getSize());
+ builder.addNameHintDecoration(func, nameSb.getUnownedSlice());
+
+ auto funcType = builder.getFuncType(1, &anyValInfo->type, type);
+ func->setFullType(funcType);
+ builder.setInsertInto(func);
+
+ builder.emitBlock();
+
+ auto param = builder.emitParam(anyValInfo->type);
+ auto anyValueVar = builder.emitVar(anyValInfo->type);
+ builder.emitStore(anyValueVar, param);
+ auto resultVar = builder.emitVar(type);
+
+ TypeUnpackingContext context;
+ context.anyValInfo = anyValInfo;
+ context.fieldOffset = context.intraFieldOffset = 0;
+ context.uintPtrType = builder.getPtrType(builder.getUIntType());
+ context.anyValueVar = anyValueVar;
+ emitMarshallingCode(&builder, &context, resultVar);
+ auto load = builder.emitLoad(resultVar);
+ builder.emitReturn(load);
+ return func;
+ }
+
+ // Ensures the marshalling functions between `type` and `anyValueType` are already generated.
+ // Returns the generated marshalling functions.
+ MarshallingFunctionSet ensureMarshallingFunc(IRType* type, IRAnyValueType* anyValueType)
+ {
+ auto size = getIntVal(anyValueType->getSize());
+ MarshallingFunctionKey key;
+ key.originalType = type;
+ key.anyValueSize = size;
+ MarshallingFunctionSet funcSet;
+ if (mapTypeMarshalingFunctions.tryGetValue(key, funcSet))
return funcSet;
- }
+ funcSet.packFunc = generatePackingFunc(type, anyValueType);
+ funcSet.unpackFunc = generateUnpackingFunc(type, anyValueType);
+ mapTypeMarshalingFunctions[key] = funcSet;
+ return funcSet;
+ }
- void processPackInst(IRPackAnyValue* packInst)
- {
- auto operand = packInst->getValue();
- auto func = ensureMarshallingFunc(
- operand->getDataType(),
- cast<IRAnyValueType>(packInst->getDataType()));
- IRBuilder builderStorage(sharedContext->module);
- auto builder = &builderStorage;
- builder->setInsertBefore(packInst);
- auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand);
- packInst->replaceUsesWith(callInst);
- packInst->removeAndDeallocate();
- }
+ void processPackInst(IRPackAnyValue* packInst)
+ {
+ auto operand = packInst->getValue();
+ auto func = ensureMarshallingFunc(
+ operand->getDataType(),
+ cast<IRAnyValueType>(packInst->getDataType()));
+ IRBuilder builderStorage(sharedContext->module);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(packInst);
+ auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand);
+ packInst->replaceUsesWith(callInst);
+ packInst->removeAndDeallocate();
+ }
- void processUnpackInst(IRUnpackAnyValue* unpackInst)
- {
- auto operand = unpackInst->getValue();
- auto func = ensureMarshallingFunc(
- unpackInst->getDataType(),
- cast<IRAnyValueType>(operand->getDataType()));
- IRBuilder builderStorage(sharedContext->module);
- auto builder = &builderStorage;
- builder->setInsertBefore(unpackInst);
- auto callInst = builder->emitCallInst(unpackInst->getDataType(), func.unpackFunc, 1, &operand);
- unpackInst->replaceUsesWith(callInst);
- unpackInst->removeAndDeallocate();
- }
+ void processUnpackInst(IRUnpackAnyValue* unpackInst)
+ {
+ auto operand = unpackInst->getValue();
+ auto func = ensureMarshallingFunc(
+ unpackInst->getDataType(),
+ cast<IRAnyValueType>(operand->getDataType()));
+ IRBuilder builderStorage(sharedContext->module);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(unpackInst);
+ auto callInst =
+ builder->emitCallInst(unpackInst->getDataType(), func.unpackFunc, 1, &operand);
+ unpackInst->replaceUsesWith(callInst);
+ unpackInst->removeAndDeallocate();
+ }
- void processAnyValueType(IRAnyValueType* type)
- {
- auto info = ensureAnyValueType(type);
- type->replaceUsesWith(info->type);
- }
+ void processAnyValueType(IRAnyValueType* type)
+ {
+ auto info = ensureAnyValueType(type);
+ type->replaceUsesWith(info->type);
+ }
- void processInst(IRInst* inst)
+ void processInst(IRInst* inst)
+ {
+ if (auto packInst = as<IRPackAnyValue>(inst))
{
- if (auto packInst = as<IRPackAnyValue>(inst))
- {
- processPackInst(packInst);
- }
- else if (auto unpackInst = as<IRUnpackAnyValue>(inst))
- {
- processUnpackInst(unpackInst);
- }
+ processPackInst(packInst);
}
-
- void processModule()
+ else if (auto unpackInst = as<IRUnpackAnyValue>(inst))
{
- // We start by initializing our shared IR building state,
- // since we will re-use that state for any code we
- // generate along the way.
- //
- sharedContext->addToWorkList(sharedContext->module->getModuleInst());
+ processUnpackInst(unpackInst);
+ }
+ }
- while (sharedContext->workList.getCount() != 0)
- {
- IRInst* inst = sharedContext->workList.getLast();
+ void processModule()
+ {
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ sharedContext->addToWorkList(sharedContext->module->getModuleInst());
- sharedContext->workList.removeLast();
- sharedContext->workListSet.remove(inst);
+ while (sharedContext->workList.getCount() != 0)
+ {
+ IRInst* inst = sharedContext->workList.getLast();
- processInst(inst);
+ sharedContext->workList.removeLast();
+ sharedContext->workListSet.remove(inst);
- for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
- {
- sharedContext->addToWorkList(child);
- }
- }
+ processInst(inst);
- // Finally, replace all `AnyValueType` with the actual struct type that implements it.
- for (auto inst : sharedContext->module->getModuleInst()->getChildren())
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
{
- if (auto anyValueType = as<IRAnyValueType>(inst))
- processAnyValueType(anyValueType);
+ sharedContext->addToWorkList(child);
}
- sharedContext->mapInterfaceRequirementKeyValue.clear();
}
- };
- void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext)
- {
- AnyValueMarshallingContext context;
- context.sharedContext = sharedContext;
- context.processModule();
+ // Finally, replace all `AnyValueType` with the actual struct type that implements it.
+ for (auto inst : sharedContext->module->getModuleInst()->getChildren())
+ {
+ if (auto anyValueType = as<IRAnyValueType>(inst))
+ processAnyValueType(anyValueType);
+ }
+ sharedContext->mapInterfaceRequirementKeyValue.clear();
}
+};
- SlangInt alignUp(SlangInt x, SlangInt alignment)
- {
- return (x + alignment - 1) / alignment * alignment;
- }
+void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext)
+{
+ AnyValueMarshallingContext context;
+ context.sharedContext = sharedContext;
+ context.processModule();
+}
- SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
+SlangInt alignUp(SlangInt x, SlangInt alignment)
+{
+ return (x + alignment - 1) / alignment * alignment;
+}
+
+SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
+{
+ switch (type->getOp())
{
- switch (type->getOp())
- {
- case kIROp_IntType:
- case kIROp_FloatType:
- case kIROp_UIntType:
- case kIROp_BoolType:
- return alignUp(offset, 4) + 4;
- case kIROp_UInt64Type:
- case kIROp_Int64Type:
- case kIROp_DoubleType:
- return -1;
- case kIROp_Int16Type:
- case kIROp_UInt16Type:
- case kIROp_HalfType:
- return alignUp(offset, 2) + 2;
- case kIROp_UInt8Type:
- case kIROp_Int8Type:
- return -1;
- case kIROp_VectorType:
+ case kIROp_IntType:
+ case kIROp_FloatType:
+ case kIROp_UIntType:
+ case kIROp_BoolType: return alignUp(offset, 4) + 4;
+ case kIROp_UInt64Type:
+ case kIROp_Int64Type:
+ case kIROp_DoubleType: return -1;
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ case kIROp_HalfType: return alignUp(offset, 2) + 2;
+ case kIROp_UInt8Type:
+ case kIROp_Int8Type: return -1;
+ case kIROp_VectorType:
{
auto vectorType = static_cast<IRVectorType*>(type);
auto elementType = vectorType->getElementType();
@@ -715,11 +745,12 @@ namespace Slang
for (IRIntegerValue i = 0; i < elementCount; i++)
{
offset = _getAnyValueSizeRaw(elementType, offset);
- if (offset < 0) return offset;
+ if (offset < 0)
+ return offset;
}
return offset;
}
- case kIROp_MatrixType:
+ case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(type);
auto elementType = matrixType->getElementType();
@@ -730,83 +761,95 @@ namespace Slang
for (IRIntegerValue j = 0; j < rowCount; j++)
{
offset = _getAnyValueSizeRaw(elementType, offset);
- if (offset < 0) return offset;
+ if (offset < 0)
+ return offset;
}
}
return offset;
}
- case kIROp_StructType:
+ case kIROp_StructType:
{
auto structType = cast<IRStructType>(type);
for (auto field : structType->getFields())
{
offset = _getAnyValueSizeRaw(field->getFieldType(), offset);
- if (offset < 0) return offset;
+ if (offset < 0)
+ return offset;
}
return offset;
}
- case kIROp_ArrayType:
+ case kIROp_ArrayType:
{
auto arrayType = cast<IRArrayType>(type);
for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++)
{
offset = _getAnyValueSizeRaw(arrayType->getElementType(), offset);
- if (offset < 0) return offset;
+ if (offset < 0)
+ return offset;
}
return offset;
}
- case kIROp_AnyValueType:
+ case kIROp_AnyValueType:
{
auto anyValueType = cast<IRAnyValueType>(type);
return alignUp(offset, 4) + (SlangInt)getIntVal(anyValueType->getSize());
}
- case kIROp_TupleType:
+ case kIROp_TupleType:
{
auto tupleType = cast<IRTupleType>(type);
for (UInt i = 0; i < tupleType->getOperandCount(); i++)
{
auto elementType = tupleType->getOperand(i);
offset = _getAnyValueSizeRaw((IRType*)elementType, offset);
- if (offset < 0) return offset;
+ if (offset < 0)
+ return offset;
}
return offset;
}
- case kIROp_WitnessTableType:
- case kIROp_WitnessTableIDType:
- case kIROp_RTTIHandleType:
+ case kIROp_WitnessTableType:
+ case kIROp_WitnessTableIDType:
+ case kIROp_RTTIHandleType:
{
return alignUp(offset, 4) + kRTTIHandleSize;
}
- case kIROp_InterfaceType:
+ case kIROp_InterfaceType:
{
auto interfaceType = cast<IRInterfaceType>(type);
- auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
+ auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(
+ interfaceType,
+ interfaceType->sourceLoc);
size += kRTTIHeaderSize;
return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
}
- case kIROp_AssociatedType:
+ case kIROp_AssociatedType:
{
auto associatedType = cast<IRAssociatedType>(type);
SlangInt maxSize = 0;
for (UInt i = 0; i < associatedType->getOperandCount(); i++)
- maxSize = Math::Max(maxSize, _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset));
+ maxSize = Math::Max(
+ maxSize,
+ _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset));
return maxSize;
}
- case kIROp_ThisType:
+ case kIROp_ThisType:
{
auto thisType = cast<IRThisType>(type);
auto interfaceType = thisType->getConstraintType();
- auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
+ auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(
+ interfaceType,
+ interfaceType->sourceLoc);
return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
}
- case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialType:
{
auto existentialValue = type->getOperand(0);
auto interfaceType = cast<IRInterfaceType>(existentialValue->getDataType());
- auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
+ auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(
+ interfaceType,
+ interfaceType->sourceLoc);
return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
}
- case kIROp_LookupWitness:
+ case kIROp_LookupWitness:
{
auto witnessTableVal = type->getOperand(0);
auto key = type->getOperand(1);
@@ -815,13 +858,13 @@ namespace Slang
{
auto interfaceType = as<IRInterfaceType>(witnessTableType->getConformanceType());
- // Walk through interface operands to find a match, the result should be an
+ // Walk through interface operands to find a match, the result should be an
// associated type entry.
//
for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++)
{
auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii));
- if (entry->getRequirementKey() == key &&
+ if (entry->getRequirementKey() == key &&
as<IRAssociatedType>(entry->getRequirementVal()))
{
assocType = (IRType*)entry->getRequirementVal();
@@ -832,33 +875,36 @@ namespace Slang
if (!assocType)
return -1;
-
+
IRIntegerValue anyValueSize = kInvalidAnyValueSize;
for (UInt i = 0; i < assocType->getOperandCount(); i++)
{
anyValueSize = Math::Min(
anyValueSize,
- SharedGenericsLoweringContext::getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc));
+ SharedGenericsLoweringContext::getInterfaceAnyValueSize(
+ assocType->getOperand(i),
+ type->sourceLoc));
}
if (anyValueSize == kInvalidAnyValueSize)
return -1;
-
+
return alignUp(offset, 4) + alignUp((SlangInt)anyValueSize, 4);
}
- default:
- if (isResourceType(type))
- {
- return alignUp(offset, 4) + 8;
- }
- return -1;
+ default:
+ if (isResourceType(type))
+ {
+ return alignUp(offset, 4) + 8;
}
+ return -1;
}
+}
- SlangInt getAnyValueSize(IRType* type)
- {
- auto rawSize = _getAnyValueSizeRaw(type, 0);
- if (rawSize < 0) return rawSize;
- return alignUp(rawSize, 4);
- }
+SlangInt getAnyValueSize(IRType* type)
+{
+ auto rawSize = _getAnyValueSizeRaw(type, 0);
+ if (rawSize < 0)
+ return rawSize;
+ return alignUp(rawSize, 4);
}
+} // namespace Slang