diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-any-value-marshalling.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-any-value-marshalling.cpp')
| -rw-r--r-- | source/slang/slang-ir-any-value-marshalling.cpp | 938 |
1 files changed, 492 insertions, 446 deletions
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index f09294aa9..b51ab4608 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -2,147 +2,153 @@ #include "../core/slang-math.h" #include "slang-ir-generics-lowering-context.h" -#include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir.h" #include "slang-legalize-types.h" namespace Slang { - // This is a subpass of generics lowering IR transformation. - // This pass generates packing/unpacking functions for `AnyValue`s, - // and replaces all `IRPackAnyValue` and `IRUnpackAnyValue` with calls to these - // functions. - struct AnyValueMarshallingContext - { - SharedGenericsLoweringContext* sharedContext; +// This is a subpass of generics lowering IR transformation. +// This pass generates packing/unpacking functions for `AnyValue`s, +// and replaces all `IRPackAnyValue` and `IRUnpackAnyValue` with calls to these +// functions. +struct AnyValueMarshallingContext +{ + SharedGenericsLoweringContext* sharedContext; - // Stores information about generated `AnyValue` struct types. - struct AnyValueTypeInfo : RefObject - { - IRType* type; // The generated IR value for the `AnyValue<N>` struct type. - List<IRStructKey*> fieldKeys; // `IRStructKey`s for the fields of the generated type. - }; + // Stores information about generated `AnyValue` struct types. + struct AnyValueTypeInfo : RefObject + { + IRType* type; // The generated IR value for the `AnyValue<N>` struct type. + List<IRStructKey*> fieldKeys; // `IRStructKey`s for the fields of the generated type. + }; - Dictionary<IRIntegerValue, RefPtr<AnyValueTypeInfo>> generatedAnyValueTypes; + Dictionary<IRIntegerValue, RefPtr<AnyValueTypeInfo>> generatedAnyValueTypes; - struct MarshallingFunctionKey + struct MarshallingFunctionKey + { + IRType* originalType; + IRIntegerValue anyValueSize; + bool operator==(MarshallingFunctionKey other) const { - IRType* originalType; - IRIntegerValue anyValueSize; - bool operator ==(MarshallingFunctionKey other) const - { - return originalType == other.originalType && anyValueSize == other.anyValueSize; - } - HashCode getHashCode() const - { - return combineHash(Slang::getHashCode(originalType), Slang::getHashCode(anyValueSize)); - } - }; - - struct MarshallingFunctionSet + return originalType == other.originalType && anyValueSize == other.anyValueSize; + } + HashCode getHashCode() const { - IRFunc* packFunc; - IRFunc* unpackFunc; - }; + return combineHash(Slang::getHashCode(originalType), Slang::getHashCode(anyValueSize)); + } + }; - // Stores the generated packing/unpacking functions for lookup. - Dictionary<MarshallingFunctionKey, MarshallingFunctionSet> mapTypeMarshalingFunctions; + struct MarshallingFunctionSet + { + IRFunc* packFunc; + IRFunc* unpackFunc; + }; - AnyValueTypeInfo* ensureAnyValueType(IRAnyValueType* type) + // Stores the generated packing/unpacking functions for lookup. + Dictionary<MarshallingFunctionKey, MarshallingFunctionSet> mapTypeMarshalingFunctions; + + AnyValueTypeInfo* ensureAnyValueType(IRAnyValueType* type) + { + auto size = getIntVal(type->getSize()); + if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size)) + return typeInfo->Ptr(); + RefPtr<AnyValueTypeInfo> info = new AnyValueTypeInfo(); + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(type); + auto structType = builder.createStructType(); + info->type = structType; + StringBuilder nameSb; + nameSb << "AnyValue" << size; + builder.addExportDecoration(structType, nameSb.getUnownedSlice()); + auto fieldCount = (size + sizeof(uint32_t) - 1) / sizeof(uint32_t); + for (decltype(fieldCount) i = 0; i < fieldCount; i++) { - auto size = getIntVal(type->getSize()); - if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size)) - return typeInfo->Ptr(); - RefPtr<AnyValueTypeInfo> info = new AnyValueTypeInfo(); - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(type); - auto structType = builder.createStructType(); - info->type = structType; - StringBuilder nameSb; - nameSb << "AnyValue" << size; - builder.addExportDecoration(structType, nameSb.getUnownedSlice()); - auto fieldCount = (size + sizeof(uint32_t) - 1) / sizeof(uint32_t); - for (decltype(fieldCount) i = 0; i < fieldCount; i++) - { - auto key = builder.createStructKey(); - nameSb.clear(); - nameSb << "field" << i; - builder.addNameHintDecoration(key, nameSb.getUnownedSlice()); - nameSb << "_anyVal" << size; - builder.addExportDecoration(key, nameSb.getUnownedSlice()); - builder.createStructField(structType, key, builder.getUIntType()); - info->fieldKeys.add(key); - } - generatedAnyValueTypes[size] = info; - return info.Ptr(); + auto key = builder.createStructKey(); + nameSb.clear(); + nameSb << "field" << i; + builder.addNameHintDecoration(key, nameSb.getUnownedSlice()); + nameSb << "_anyVal" << size; + builder.addExportDecoration(key, nameSb.getUnownedSlice()); + builder.createStructField(structType, key, builder.getUIntType()); + info->fieldKeys.add(key); } + generatedAnyValueTypes[size] = info; + return info.Ptr(); + } - struct TypeMarshallingContext + struct TypeMarshallingContext + { + AnyValueTypeInfo* anyValInfo; + uint32_t fieldOffset; + uint32_t intraFieldOffset; + IRType* uintPtrType; + IRInst* anyValueVar; + // Defines what to do with basic typed data elements. + virtual void marshalBasicType( + IRBuilder* builder, + IRType* dataType, + IRInst* concreteTypedVar) = 0; + // Defines what to do with resource handle elements. + virtual void marshalResourceHandle( + IRBuilder* builder, + IRType* dataType, + IRInst* concreteTypedVar) = 0; + + void ensureOffsetAt4ByteBoundary() { - AnyValueTypeInfo* anyValInfo; - uint32_t fieldOffset; - uint32_t intraFieldOffset; - IRType* uintPtrType; - IRInst* anyValueVar; - // Defines what to do with basic typed data elements. - virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0; - // Defines what to do with resource handle elements. - virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0; - - void ensureOffsetAt4ByteBoundary() - { - if (intraFieldOffset) - { - fieldOffset++; - intraFieldOffset = 0; - } - } - void ensureOffsetAt2ByteBoundary() + if (intraFieldOffset) { - if (intraFieldOffset == 0) - return; - if (intraFieldOffset <= 2) - { - intraFieldOffset = 2; - return; - } fieldOffset++; intraFieldOffset = 0; - return; } - void advanceOffset(uint32_t bytes) + } + void ensureOffsetAt2ByteBoundary() + { + if (intraFieldOffset == 0) + return; + if (intraFieldOffset <= 2) { - intraFieldOffset += bytes; - fieldOffset += intraFieldOffset / 4; - intraFieldOffset = intraFieldOffset % 4; + intraFieldOffset = 2; + return; } - }; + fieldOffset++; + intraFieldOffset = 0; + return; + } + void advanceOffset(uint32_t bytes) + { + intraFieldOffset += bytes; + fieldOffset += intraFieldOffset / 4; + intraFieldOffset = intraFieldOffset % 4; + } + }; - void emitMarshallingCode( - IRBuilder* builder, - TypeMarshallingContext* context, - IRInst* concreteTypedVar) + void emitMarshallingCode( + IRBuilder* builder, + TypeMarshallingContext* context, + IRInst* concreteTypedVar) + { + auto dataType = cast<IRPtrTypeBase>(concreteTypedVar->getDataType())->getValueType(); + switch (dataType->getOp()) { - auto dataType = cast<IRPtrTypeBase>(concreteTypedVar->getDataType())->getValueType(); - switch (dataType->getOp()) - { - case kIROp_IntType: - case kIROp_FloatType: - case kIROp_UIntType: - case kIROp_UInt64Type: - case kIROp_Int64Type: - case kIROp_DoubleType: - case kIROp_Int8Type: - case kIROp_Int16Type: - case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_HalfType: - case kIROp_BoolType: - case kIROp_IntPtrType: - case kIROp_UIntPtrType: - context->marshalBasicType(builder, dataType, concreteTypedVar); - break; - case kIROp_VectorType: + case kIROp_IntType: + case kIROp_FloatType: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_Int64Type: + case kIROp_DoubleType: + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_HalfType: + case kIROp_BoolType: + case kIROp_IntPtrType: + case kIROp_UIntPtrType: + context->marshalBasicType(builder, dataType, concreteTypedVar); + break; + case kIROp_VectorType: { auto vectorType = static_cast<IRVectorType*>(dataType); auto elementCount = getIntVal(vectorType->getElementCount()); @@ -155,7 +161,7 @@ namespace Slang } break; } - case kIROp_MatrixType: + case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(dataType); auto colCount = getIntVal(matrixType->getColumnCount()); @@ -175,7 +181,7 @@ namespace Slang } break; } - case kIROp_StructType: + case kIROp_StructType: { auto structType = cast<IRStructType>(dataType); for (auto field : structType->getFields()) @@ -188,7 +194,7 @@ namespace Slang } break; } - case kIROp_ArrayType: + case kIROp_ArrayType: { auto arrayType = cast<IRArrayType>(dataType); for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++) @@ -200,7 +206,7 @@ namespace Slang } break; } - case kIROp_AnyValueType: + case kIROp_AnyValueType: { auto anyValType = cast<IRAnyValueType>(dataType); auto info = ensureAnyValueType(anyValType); @@ -214,27 +220,28 @@ namespace Slang } break; } - default: - if (isResourceType(dataType)) - { - context->marshalResourceHandle(builder, dataType, concreteTypedVar); - return; - } - SLANG_UNIMPLEMENTED_X("Unimplemented type packing"); - break; + default: + if (isResourceType(dataType)) + { + context->marshalResourceHandle(builder, dataType, concreteTypedVar); + return; } + SLANG_UNIMPLEMENTED_X("Unimplemented type packing"); + break; } + } - struct TypePackingContext : TypeMarshallingContext + struct TypePackingContext : TypeMarshallingContext + { + virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) + override { - virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override + switch (dataType->getOp()) { - switch (dataType->getOp()) - { - case kIROp_IntType: - case kIROp_FloatType: + case kIROp_IntType: + case kIROp_FloatType: #if SLANG_PTR_IS_32 - case kIROp_IntPtrType: + case kIROp_IntPtrType: #endif { ensureOffsetAt4ByteBoundary(); @@ -251,14 +258,21 @@ namespace Slang advanceOffset(4); break; } - case kIROp_BoolType: + case kIROp_BoolType: { ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) { auto srcVal = builder->emitLoad(concreteVar); - IRInst* args[] = {srcVal, builder->getIntValue(builder->getUIntType(), 1), builder->getIntValue(builder->getUIntType(), 0) }; - auto dstVal = builder->emitIntrinsicInst(builder->getUIntType(), kIROp_Select, 3, args); + IRInst* args[] = { + srcVal, + builder->getIntValue(builder->getUIntType(), 1), + builder->getIntValue(builder->getUIntType(), 0)}; + auto dstVal = builder->emitIntrinsicInst( + builder->getUIntType(), + kIROp_Select, + 3, + args); auto dstAddr = builder->emitFieldAddress( uintPtrType, anyValueVar, @@ -268,9 +282,9 @@ namespace Slang advanceOffset(4); break; } - case kIROp_UIntType: + case kIROp_UIntType: #if SLANG_PTR_IS_32 - case kIROp_UIntPtrType: + case kIROp_UIntPtrType: #endif { ensureOffsetAt4ByteBoundary(); @@ -286,9 +300,9 @@ namespace Slang advanceOffset(4); break; } - case kIROp_Int16Type: - case kIROp_UInt16Type: - case kIROp_HalfType: + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: { ensureOffsetAt2ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) @@ -296,7 +310,8 @@ namespace Slang auto srcVal = builder->emitLoad(concreteVar); if (dataType->getOp() == kIROp_HalfType) { - srcVal = builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal); + srcVal = + builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal); } srcVal = builder->emitCast(builder->getType(kIROp_UIntType), srcVal); auto dstAddr = builder->emitFieldAddress( @@ -307,16 +322,19 @@ namespace Slang if (intraFieldOffset == 0) { dstVal = builder->emitBitAnd( - dstVal->getFullType(), dstVal, + dstVal->getFullType(), + dstVal, builder->getIntValue(builder->getUIntType(), 0xFFFF0000)); } else { srcVal = builder->emitShl( - srcVal->getFullType(), srcVal, + srcVal->getFullType(), + srcVal, builder->getIntValue(builder->getUIntType(), 16)); dstVal = builder->emitBitAnd( - dstVal->getFullType(), dstVal, + dstVal->getFullType(), + dstVal, builder->getIntValue(builder->getUIntType(), 0xFFFF)); } dstVal = builder->emitBitOr(dstVal->getFullType(), dstVal, srcVal); @@ -325,105 +343,111 @@ namespace Slang advanceOffset(2); break; } - case kIROp_Int8Type: - case kIROp_UInt8Type: - case kIROp_UInt64Type: - case kIROp_Int64Type: - case kIROp_DoubleType: + case kIROp_Int8Type: + case kIROp_UInt8Type: + case kIROp_UInt64Type: + case kIROp_Int64Type: + case kIROp_DoubleType: #if SLANG_PTR_IS_64 - case kIROp_UIntPtrType: - case kIROp_IntPtrType: + case kIROp_UIntPtrType: + case kIROp_IntPtrType: #endif - SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); - break; - default: - SLANG_UNREACHABLE("unknown basic type"); - } + SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); + break; + default: SLANG_UNREACHABLE("unknown basic type"); } + } - virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override + virtual void marshalResourceHandle( + IRBuilder* builder, + IRType* dataType, + IRInst* concreteVar) override + { + SLANG_UNUSED(dataType); + ensureOffsetAt4ByteBoundary(); + if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) { - SLANG_UNUSED(dataType); - ensureOffsetAt4ByteBoundary(); - if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) - { - auto srcVal = builder->emitLoad(concreteVar); - auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal); - auto lowBits = builder->emitCast(builder->getUIntType(), uint64Val); - auto shiftedBits = builder->emitShr( - builder->getUInt64Type(), - uint64Val, - builder->getIntValue(builder->getIntType(), 32)); - auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits); - auto dstAddr1 = builder->emitFieldAddress( - uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]); - builder->emitStore(dstAddr1, lowBits); - auto dstAddr2 = builder->emitFieldAddress( - uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]); - builder->emitStore(dstAddr2, highBits); - advanceOffset(8); - } + auto srcVal = builder->emitLoad(concreteVar); + auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal); + auto lowBits = builder->emitCast(builder->getUIntType(), uint64Val); + auto shiftedBits = builder->emitShr( + builder->getUInt64Type(), + uint64Val, + builder->getIntValue(builder->getIntType(), 32)); + auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits); + auto dstAddr1 = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + builder->emitStore(dstAddr1, lowBits); + auto dstAddr2 = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset + 1]); + builder->emitStore(dstAddr2, highBits); + advanceOffset(8); } - }; + } + }; - IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(type); - auto anyValInfo = ensureAnyValueType(anyValueType); + IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType) + { + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(type); + auto anyValInfo = ensureAnyValueType(anyValueType); - auto func = builder.createFunc(); + auto func = builder.createFunc(); - StringBuilder nameSb; - nameSb << "packAnyValue" << getIntVal(anyValueType->getSize()); - builder.addNameHintDecoration(func, nameSb.getUnownedSlice()); - // Currently we don't add linkage to the generated func, since we - // do not have a way to compute mangled names from an IR entity. - // This will leads to duplicate packing functions in linked code - // but there won't be correctness issues. + StringBuilder nameSb; + nameSb << "packAnyValue" << getIntVal(anyValueType->getSize()); + builder.addNameHintDecoration(func, nameSb.getUnownedSlice()); + // Currently we don't add linkage to the generated func, since we + // do not have a way to compute mangled names from an IR entity. + // This will leads to duplicate packing functions in linked code + // but there won't be correctness issues. - auto funcType = builder.getFuncType(1, &type, anyValInfo->type); - func->setFullType(funcType); - builder.setInsertInto(func); + auto funcType = builder.getFuncType(1, &type, anyValInfo->type); + func->setFullType(funcType); + builder.setInsertInto(func); - builder.emitBlock(); + builder.emitBlock(); - auto param = builder.emitParam(type); - auto concreteTypedVar = builder.emitVar(type); - builder.emitStore(concreteTypedVar, param); - auto resultVar = builder.emitVar(anyValInfo->type); + auto param = builder.emitParam(type); + auto concreteTypedVar = builder.emitVar(type); + builder.emitStore(concreteTypedVar, param); + auto resultVar = builder.emitVar(anyValInfo->type); - // Initialize fields to 0 to prevent downstream compiler error. - for (uint32_t offset = 0; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++) - { - auto fieldAddr = builder.emitFieldAddress( - builder.getPtrType(builder.getUIntType()), - resultVar, - anyValInfo->fieldKeys[offset] - ); - builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0)); - } + // Initialize fields to 0 to prevent downstream compiler error. + for (uint32_t offset = 0; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++) + { + auto fieldAddr = builder.emitFieldAddress( + builder.getPtrType(builder.getUIntType()), + resultVar, + anyValInfo->fieldKeys[offset]); + builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0)); + } - TypePackingContext context; - context.anyValInfo = anyValInfo; - context.fieldOffset = context.intraFieldOffset = 0; - context.uintPtrType = builder.getPtrType(builder.getUIntType()); - context.anyValueVar = resultVar; - emitMarshallingCode(&builder, &context, concreteTypedVar); + TypePackingContext context; + context.anyValInfo = anyValInfo; + context.fieldOffset = context.intraFieldOffset = 0; + context.uintPtrType = builder.getPtrType(builder.getUIntType()); + context.anyValueVar = resultVar; + emitMarshallingCode(&builder, &context, concreteTypedVar); - auto load = builder.emitLoad(resultVar); - builder.emitReturn(load); - return func; - } + auto load = builder.emitLoad(resultVar); + builder.emitReturn(load); + return func; + } - struct TypeUnpackingContext : TypeMarshallingContext + struct TypeUnpackingContext : TypeMarshallingContext + { + virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) + override { - virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override + switch (dataType->getOp()) { - switch (dataType->getOp()) - { - case kIROp_IntType: - case kIROp_FloatType: + case kIROp_IntType: + case kIROp_FloatType: { ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) @@ -439,7 +463,7 @@ namespace Slang advanceOffset(4); break; } - case kIROp_BoolType: + case kIROp_BoolType: { ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) @@ -449,13 +473,15 @@ namespace Slang anyValueVar, anyValInfo->fieldKeys[fieldOffset]); auto srcVal = builder->emitLoad(srcAddr); - srcVal = builder->emitNeq(srcVal, builder->getIntValue(builder->getUIntType(), 0)); + srcVal = builder->emitNeq( + srcVal, + builder->getIntValue(builder->getUIntType(), 0)); builder->emitStore(concreteVar, srcVal); } advanceOffset(4); break; } - case kIROp_UIntType: + case kIROp_UIntType: { ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) @@ -470,9 +496,9 @@ namespace Slang advanceOffset(4); break; } - case kIROp_Int16Type: - case kIROp_UInt16Type: - case kIROp_HalfType: + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: { ensureOffsetAt2ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) @@ -485,13 +511,15 @@ namespace Slang if (intraFieldOffset == 0) { srcVal = builder->emitBitAnd( - srcVal->getFullType(), srcVal, + srcVal->getFullType(), + srcVal, builder->getIntValue(builder->getUIntType(), 0xFFFF)); } else { srcVal = builder->emitShr( - srcVal->getFullType(), srcVal, + srcVal->getFullType(), + srcVal, builder->getIntValue(builder->getUIntType(), 16)); } if (dataType->getOp() == kIROp_Int16Type) @@ -511,203 +539,205 @@ namespace Slang advanceOffset(2); break; } - case kIROp_UInt64Type: - case kIROp_Int64Type: - case kIROp_DoubleType: - case kIROp_Int8Type: - case kIROp_UInt8Type: - SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); - break; - default: - SLANG_UNREACHABLE("unknown basic type"); - } + case kIROp_UInt64Type: + case kIROp_Int64Type: + case kIROp_DoubleType: + case kIROp_Int8Type: + case kIROp_UInt8Type: + SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); + break; + default: SLANG_UNREACHABLE("unknown basic type"); } + } - virtual void marshalResourceHandle( - IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override + virtual void marshalResourceHandle( + IRBuilder* builder, + IRType* dataType, + IRInst* concreteVar) override + { + ensureOffsetAt4ByteBoundary(); + if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) { - ensureOffsetAt4ByteBoundary(); - if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) - { - auto srcAddr = builder->emitFieldAddress( - uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]); - auto lowBits = builder->emitLoad(srcAddr); - - auto srcAddr1 = builder->emitFieldAddress( - uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]); - auto highBits = builder->emitLoad(srcAddr1); - - auto combinedBits = builder->emitMakeUInt64(lowBits, highBits); - combinedBits = builder->emitBitCast(dataType, combinedBits); - builder->emitStore(concreteVar, combinedBits); - advanceOffset(8); - } + auto srcAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + auto lowBits = builder->emitLoad(srcAddr); + + auto srcAddr1 = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset + 1]); + auto highBits = builder->emitLoad(srcAddr1); + + auto combinedBits = builder->emitMakeUInt64(lowBits, highBits); + combinedBits = builder->emitBitCast(dataType, combinedBits); + builder->emitStore(concreteVar, combinedBits); + advanceOffset(8); } - }; - - IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(type); - auto anyValInfo = ensureAnyValueType(anyValueType); - - auto func = builder.createFunc(); - - StringBuilder nameSb; - nameSb << "unpackAnyValue" << getIntVal(anyValueType->getSize()); - builder.addNameHintDecoration(func, nameSb.getUnownedSlice()); - - auto funcType = builder.getFuncType(1, &anyValInfo->type, type); - func->setFullType(funcType); - builder.setInsertInto(func); - - builder.emitBlock(); - - auto param = builder.emitParam(anyValInfo->type); - auto anyValueVar = builder.emitVar(anyValInfo->type); - builder.emitStore(anyValueVar, param); - auto resultVar = builder.emitVar(type); - - TypeUnpackingContext context; - context.anyValInfo = anyValInfo; - context.fieldOffset = context.intraFieldOffset = 0; - context.uintPtrType = builder.getPtrType(builder.getUIntType()); - context.anyValueVar = anyValueVar; - emitMarshallingCode(&builder, &context, resultVar); - auto load = builder.emitLoad(resultVar); - builder.emitReturn(load); - return func; } + }; - // Ensures the marshalling functions between `type` and `anyValueType` are already generated. - // Returns the generated marshalling functions. - MarshallingFunctionSet ensureMarshallingFunc(IRType* type, IRAnyValueType* anyValueType) - { - auto size = getIntVal(anyValueType->getSize()); - MarshallingFunctionKey key; - key.originalType = type; - key.anyValueSize = size; - MarshallingFunctionSet funcSet; - if (mapTypeMarshalingFunctions.tryGetValue(key, funcSet)) - return funcSet; - funcSet.packFunc = generatePackingFunc(type, anyValueType); - funcSet.unpackFunc = generateUnpackingFunc(type, anyValueType); - mapTypeMarshalingFunctions[key] = funcSet; + IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType) + { + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(type); + auto anyValInfo = ensureAnyValueType(anyValueType); + + auto func = builder.createFunc(); + + StringBuilder nameSb; + nameSb << "unpackAnyValue" << getIntVal(anyValueType->getSize()); + builder.addNameHintDecoration(func, nameSb.getUnownedSlice()); + + auto funcType = builder.getFuncType(1, &anyValInfo->type, type); + func->setFullType(funcType); + builder.setInsertInto(func); + + builder.emitBlock(); + + auto param = builder.emitParam(anyValInfo->type); + auto anyValueVar = builder.emitVar(anyValInfo->type); + builder.emitStore(anyValueVar, param); + auto resultVar = builder.emitVar(type); + + TypeUnpackingContext context; + context.anyValInfo = anyValInfo; + context.fieldOffset = context.intraFieldOffset = 0; + context.uintPtrType = builder.getPtrType(builder.getUIntType()); + context.anyValueVar = anyValueVar; + emitMarshallingCode(&builder, &context, resultVar); + auto load = builder.emitLoad(resultVar); + builder.emitReturn(load); + return func; + } + + // Ensures the marshalling functions between `type` and `anyValueType` are already generated. + // Returns the generated marshalling functions. + MarshallingFunctionSet ensureMarshallingFunc(IRType* type, IRAnyValueType* anyValueType) + { + auto size = getIntVal(anyValueType->getSize()); + MarshallingFunctionKey key; + key.originalType = type; + key.anyValueSize = size; + MarshallingFunctionSet funcSet; + if (mapTypeMarshalingFunctions.tryGetValue(key, funcSet)) return funcSet; - } + funcSet.packFunc = generatePackingFunc(type, anyValueType); + funcSet.unpackFunc = generateUnpackingFunc(type, anyValueType); + mapTypeMarshalingFunctions[key] = funcSet; + return funcSet; + } - void processPackInst(IRPackAnyValue* packInst) - { - auto operand = packInst->getValue(); - auto func = ensureMarshallingFunc( - operand->getDataType(), - cast<IRAnyValueType>(packInst->getDataType())); - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(packInst); - auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand); - packInst->replaceUsesWith(callInst); - packInst->removeAndDeallocate(); - } + void processPackInst(IRPackAnyValue* packInst) + { + auto operand = packInst->getValue(); + auto func = ensureMarshallingFunc( + operand->getDataType(), + cast<IRAnyValueType>(packInst->getDataType())); + IRBuilder builderStorage(sharedContext->module); + auto builder = &builderStorage; + builder->setInsertBefore(packInst); + auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand); + packInst->replaceUsesWith(callInst); + packInst->removeAndDeallocate(); + } - void processUnpackInst(IRUnpackAnyValue* unpackInst) - { - auto operand = unpackInst->getValue(); - auto func = ensureMarshallingFunc( - unpackInst->getDataType(), - cast<IRAnyValueType>(operand->getDataType())); - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(unpackInst); - auto callInst = builder->emitCallInst(unpackInst->getDataType(), func.unpackFunc, 1, &operand); - unpackInst->replaceUsesWith(callInst); - unpackInst->removeAndDeallocate(); - } + void processUnpackInst(IRUnpackAnyValue* unpackInst) + { + auto operand = unpackInst->getValue(); + auto func = ensureMarshallingFunc( + unpackInst->getDataType(), + cast<IRAnyValueType>(operand->getDataType())); + IRBuilder builderStorage(sharedContext->module); + auto builder = &builderStorage; + builder->setInsertBefore(unpackInst); + auto callInst = + builder->emitCallInst(unpackInst->getDataType(), func.unpackFunc, 1, &operand); + unpackInst->replaceUsesWith(callInst); + unpackInst->removeAndDeallocate(); + } - void processAnyValueType(IRAnyValueType* type) - { - auto info = ensureAnyValueType(type); - type->replaceUsesWith(info->type); - } + void processAnyValueType(IRAnyValueType* type) + { + auto info = ensureAnyValueType(type); + type->replaceUsesWith(info->type); + } - void processInst(IRInst* inst) + void processInst(IRInst* inst) + { + if (auto packInst = as<IRPackAnyValue>(inst)) { - if (auto packInst = as<IRPackAnyValue>(inst)) - { - processPackInst(packInst); - } - else if (auto unpackInst = as<IRUnpackAnyValue>(inst)) - { - processUnpackInst(unpackInst); - } + processPackInst(packInst); } - - void processModule() + else if (auto unpackInst = as<IRUnpackAnyValue>(inst)) { - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + processUnpackInst(unpackInst); + } + } - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); + void processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); - processInst(inst); + sharedContext->workList.removeLast(); + sharedContext->workListSet.remove(inst); - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } + processInst(inst); - // Finally, replace all `AnyValueType` with the actual struct type that implements it. - for (auto inst : sharedContext->module->getModuleInst()->getChildren()) + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { - if (auto anyValueType = as<IRAnyValueType>(inst)) - processAnyValueType(anyValueType); + sharedContext->addToWorkList(child); } - sharedContext->mapInterfaceRequirementKeyValue.clear(); } - }; - void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext) - { - AnyValueMarshallingContext context; - context.sharedContext = sharedContext; - context.processModule(); + // Finally, replace all `AnyValueType` with the actual struct type that implements it. + for (auto inst : sharedContext->module->getModuleInst()->getChildren()) + { + if (auto anyValueType = as<IRAnyValueType>(inst)) + processAnyValueType(anyValueType); + } + sharedContext->mapInterfaceRequirementKeyValue.clear(); } +}; - SlangInt alignUp(SlangInt x, SlangInt alignment) - { - return (x + alignment - 1) / alignment * alignment; - } +void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext) +{ + AnyValueMarshallingContext context; + context.sharedContext = sharedContext; + context.processModule(); +} - SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) +SlangInt alignUp(SlangInt x, SlangInt alignment) +{ + return (x + alignment - 1) / alignment * alignment; +} + +SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) +{ + switch (type->getOp()) { - switch (type->getOp()) - { - case kIROp_IntType: - case kIROp_FloatType: - case kIROp_UIntType: - case kIROp_BoolType: - return alignUp(offset, 4) + 4; - case kIROp_UInt64Type: - case kIROp_Int64Type: - case kIROp_DoubleType: - return -1; - case kIROp_Int16Type: - case kIROp_UInt16Type: - case kIROp_HalfType: - return alignUp(offset, 2) + 2; - case kIROp_UInt8Type: - case kIROp_Int8Type: - return -1; - case kIROp_VectorType: + case kIROp_IntType: + case kIROp_FloatType: + case kIROp_UIntType: + case kIROp_BoolType: return alignUp(offset, 4) + 4; + case kIROp_UInt64Type: + case kIROp_Int64Type: + case kIROp_DoubleType: return -1; + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: return alignUp(offset, 2) + 2; + case kIROp_UInt8Type: + case kIROp_Int8Type: return -1; + case kIROp_VectorType: { auto vectorType = static_cast<IRVectorType*>(type); auto elementType = vectorType->getElementType(); @@ -715,11 +745,12 @@ namespace Slang for (IRIntegerValue i = 0; i < elementCount; i++) { offset = _getAnyValueSizeRaw(elementType, offset); - if (offset < 0) return offset; + if (offset < 0) + return offset; } return offset; } - case kIROp_MatrixType: + case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(type); auto elementType = matrixType->getElementType(); @@ -730,83 +761,95 @@ namespace Slang for (IRIntegerValue j = 0; j < rowCount; j++) { offset = _getAnyValueSizeRaw(elementType, offset); - if (offset < 0) return offset; + if (offset < 0) + return offset; } } return offset; } - case kIROp_StructType: + case kIROp_StructType: { auto structType = cast<IRStructType>(type); for (auto field : structType->getFields()) { offset = _getAnyValueSizeRaw(field->getFieldType(), offset); - if (offset < 0) return offset; + if (offset < 0) + return offset; } return offset; } - case kIROp_ArrayType: + case kIROp_ArrayType: { auto arrayType = cast<IRArrayType>(type); for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++) { offset = _getAnyValueSizeRaw(arrayType->getElementType(), offset); - if (offset < 0) return offset; + if (offset < 0) + return offset; } return offset; } - case kIROp_AnyValueType: + case kIROp_AnyValueType: { auto anyValueType = cast<IRAnyValueType>(type); return alignUp(offset, 4) + (SlangInt)getIntVal(anyValueType->getSize()); } - case kIROp_TupleType: + case kIROp_TupleType: { auto tupleType = cast<IRTupleType>(type); for (UInt i = 0; i < tupleType->getOperandCount(); i++) { auto elementType = tupleType->getOperand(i); offset = _getAnyValueSizeRaw((IRType*)elementType, offset); - if (offset < 0) return offset; + if (offset < 0) + return offset; } return offset; } - case kIROp_WitnessTableType: - case kIROp_WitnessTableIDType: - case kIROp_RTTIHandleType: + case kIROp_WitnessTableType: + case kIROp_WitnessTableIDType: + case kIROp_RTTIHandleType: { return alignUp(offset, 4) + kRTTIHandleSize; } - case kIROp_InterfaceType: + case kIROp_InterfaceType: { auto interfaceType = cast<IRInterfaceType>(type); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); + auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( + interfaceType, + interfaceType->sourceLoc); size += kRTTIHeaderSize; return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } - case kIROp_AssociatedType: + case kIROp_AssociatedType: { auto associatedType = cast<IRAssociatedType>(type); SlangInt maxSize = 0; for (UInt i = 0; i < associatedType->getOperandCount(); i++) - maxSize = Math::Max(maxSize, _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset)); + maxSize = Math::Max( + maxSize, + _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset)); return maxSize; } - case kIROp_ThisType: + case kIROp_ThisType: { auto thisType = cast<IRThisType>(type); auto interfaceType = thisType->getConstraintType(); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); + auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( + interfaceType, + interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } - case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialType: { auto existentialValue = type->getOperand(0); auto interfaceType = cast<IRInterfaceType>(existentialValue->getDataType()); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); + auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( + interfaceType, + interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } - case kIROp_LookupWitness: + case kIROp_LookupWitness: { auto witnessTableVal = type->getOperand(0); auto key = type->getOperand(1); @@ -815,13 +858,13 @@ namespace Slang { auto interfaceType = as<IRInterfaceType>(witnessTableType->getConformanceType()); - // Walk through interface operands to find a match, the result should be an + // Walk through interface operands to find a match, the result should be an // associated type entry. // for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++) { auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii)); - if (entry->getRequirementKey() == key && + if (entry->getRequirementKey() == key && as<IRAssociatedType>(entry->getRequirementVal())) { assocType = (IRType*)entry->getRequirementVal(); @@ -832,33 +875,36 @@ namespace Slang if (!assocType) return -1; - + IRIntegerValue anyValueSize = kInvalidAnyValueSize; for (UInt i = 0; i < assocType->getOperandCount(); i++) { anyValueSize = Math::Min( anyValueSize, - SharedGenericsLoweringContext::getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc)); + SharedGenericsLoweringContext::getInterfaceAnyValueSize( + assocType->getOperand(i), + type->sourceLoc)); } if (anyValueSize == kInvalidAnyValueSize) return -1; - + return alignUp(offset, 4) + alignUp((SlangInt)anyValueSize, 4); } - default: - if (isResourceType(type)) - { - return alignUp(offset, 4) + 8; - } - return -1; + default: + if (isResourceType(type)) + { + return alignUp(offset, 4) + 8; } + return -1; } +} - SlangInt getAnyValueSize(IRType* type) - { - auto rawSize = _getAnyValueSizeRaw(type, 0); - if (rawSize < 0) return rawSize; - return alignUp(rawSize, 4); - } +SlangInt getAnyValueSize(IRType* type) +{ + auto rawSize = _getAnyValueSizeRaw(type, 0); + if (rawSize < 0) + return rawSize; + return alignUp(rawSize, 4); } +} // namespace Slang |
