From 28adf8917e53953dbfebd746410a427a55eed814 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 9 Sep 2021 11:39:04 -0700 Subject: `reinterpret` and 16-bit value packing. (#1933) * `reinterpret` and 16-bit value packing. * Update `half-texture` cross-compile test reference result. * Revert inadvertent reformatting of slang-ir-inst-defs.h Co-authored-by: Yong He --- source/slang/core.meta.slang | 6 + source/slang/slang-diagnostic-defs.h | 3 + source/slang/slang-emit-hlsl.cpp | 31 ++- source/slang/slang-emit.cpp | 3 + source/slang/slang-ir-any-value-marshalling.cpp | 241 ++++++++++++++++++++++-- source/slang/slang-ir-any-value-marshalling.h | 7 + source/slang/slang-ir-inst-defs.h | 2 +- source/slang/slang-ir-insts.h | 1 + source/slang/slang-ir-lower-reinterpret.cpp | 109 +++++++++++ source/slang/slang-ir-lower-reinterpret.h | 16 ++ source/slang/slang-ir.cpp | 12 ++ 11 files changed, 406 insertions(+), 25 deletions(-) create mode 100644 source/slang/slang-ir-lower-reinterpret.cpp create mode 100644 source/slang/slang-ir-lower-reinterpret.h (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 28fd1a545..d2f574ef8 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1907,6 +1907,12 @@ __generic __intrinsic_op($(kIROp_CreateExistentialObject)) T createDynamicObject(uint typeId, U value); +// Reinterpret +__generic +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_Reinterpret)) +T reinterpret(U value); + // Specialized function /// Given a string returns an integer hash of that string. 
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index dcc62bdf7..b1f1bbaef 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -503,6 +503,9 @@ DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.") + +DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.") + // // 5xxxx - Target code generation. // diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 3921cbbac..2fda8ab99 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -479,7 +479,9 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu emitType(inst->getDataType()); m_writer->emit(")"); break; - + case BaseType::Half: + m_writer->emit("f16tof32"); + break; case BaseType::Float: // Note: at present HLSL only supports // reinterpreting integer bits as a `float`. 
@@ -511,11 +513,18 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu case BaseType::UInt: case BaseType::Int: break; - + case BaseType::UInt16: + case BaseType::Int16: + break; case BaseType::Float: m_writer->emit("asuint("); closeCount++; break; + + case BaseType::Half: + m_writer->emit("f32tof16("); + closeCount++; + break; } emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); @@ -750,20 +759,32 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_VoidType: case kIROp_BoolType: case kIROp_Int8Type: - case kIROp_Int16Type: case kIROp_IntType: case kIROp_Int64Type: case kIROp_UInt8Type: - case kIROp_UInt16Type: case kIROp_UIntType: case kIROp_UInt64Type: case kIROp_FloatType: case kIROp_DoubleType: - case kIROp_HalfType: { m_writer->emit(getDefaultBuiltinTypeName(type->getOp())); return; } + case kIROp_Int16Type: + { + m_writer->emit("min16int"); + return; + } + case kIROp_UInt16Type: + { + m_writer->emit("min16uint"); + return; + } + case kIROp_HalfType: + { + m_writer->emit("min16float"); + return; + } case kIROp_StructType: m_writer->emit(getName(type)); return; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 3da19cef1..951e6fad5 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -21,6 +21,7 @@ #include "slang-ir-lower-generics.h" #include "slang-ir-lower-tuple-types.h" #include "slang-ir-lower-bit-cast.h" +#include "slang-ir-lower-reinterpret.h" #include "slang-ir-optix-entry-point-uniforms.h" #include "slang-ir-restructure.h" #include "slang-ir-restructure-scoping.h" @@ -327,6 +328,8 @@ Result linkAndOptimizeIR( eliminateDeadCode(irModule); + lowerReinterpret(targetRequest, irModule, sink); + // For targets that supports dynamic dispatch, we need to lower the // generics / interface types to ordinary functions and types using // function pointers. 
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index cd76e4430..b88dd79d1 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -99,6 +99,33 @@ namespace Slang sink->diagnose(concreteType->sourceLoc, Diagnostics::typeDoesNotFitAnyValueSize, concreteType); } } + void ensureOffsetAt4ByteBoundary() + { + if (intraFieldOffset) + { + fieldOffset++; + intraFieldOffset = 0; + } + } + void ensureOffsetAt2ByteBoundary() + { + if (intraFieldOffset == 0) + return; + if (intraFieldOffset <= 2) + { + intraFieldOffset = 2; + return; + } + fieldOffset++; + intraFieldOffset = 0; + return; + } + void advanceOffset(uint32_t bytes) + { + intraFieldOffset += bytes; + fieldOffset += intraFieldOffset / 4; + intraFieldOffset = intraFieldOffset % 4; + } }; void emitMarshallingCode( @@ -208,6 +235,7 @@ namespace Slang case kIROp_IntType: case kIROp_FloatType: { + ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast(anyValInfo->fieldKeys.getCount())) { auto srcVal = builder->emitLoad(concreteVar); @@ -218,11 +246,12 @@ namespace Slang anyValInfo->fieldKeys[fieldOffset]); builder->emitStore(dstAddr, dstVal); } - fieldOffset++; + advanceOffset(4); break; } case kIROp_UIntType: { + ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast(anyValInfo->fieldKeys.getCount())) { auto srcVal = builder->emitLoad(concreteVar); @@ -232,17 +261,56 @@ namespace Slang anyValInfo->fieldKeys[fieldOffset]); builder->emitStore(dstAddr, srcVal); } - fieldOffset++; + advanceOffset(4); + break; + } + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: + { + ensureOffsetAt2ByteBoundary(); + if (fieldOffset < static_cast(anyValInfo->fieldKeys.getCount())) + { + auto srcVal = builder->emitLoad(concreteVar); + if (dataType->getOp() == kIROp_HalfType) + { + srcVal = builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal); + } + srcVal = 
builder->emitConstructorInst(builder->getType(kIROp_UIntType), 1, &srcVal); + auto dstAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + auto dstVal = builder->emitLoad(dstAddr); + if (intraFieldOffset == 0) + { + srcVal = builder->emitBitAnd( + srcVal->getFullType(), srcVal, + builder->getIntValue(builder->getUIntType(), 0xFFFF)); + dstVal = builder->emitBitAnd( + dstVal->getFullType(), dstVal, + builder->getIntValue(builder->getUIntType(), 0xFFFF0000)); + } + else + { + srcVal = builder->emitShl( + srcVal->getFullType(), srcVal, + builder->getIntValue(builder->getUIntType(), 16)); + dstVal = builder->emitBitAnd( + dstVal->getFullType(), dstVal, + builder->getIntValue(builder->getUIntType(), 0xFFFF)); + } + dstVal = builder->emitBitOr(dstVal->getFullType(), dstVal, srcVal); + builder->emitStore(dstAddr, dstVal); + } + advanceOffset(2); break; } case kIROp_UInt64Type: case kIROp_Int64Type: case kIROp_DoubleType: case kIROp_Int8Type: - case kIROp_Int16Type: case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_HalfType: SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); break; default: @@ -253,6 +321,7 @@ namespace Slang virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override { SLANG_UNUSED(dataType); + ensureOffsetAt4ByteBoundary(); if (fieldOffset + 1 < static_cast(anyValInfo->fieldKeys.getCount())) { auto srcVal = builder->emitLoad(concreteVar); @@ -269,7 +338,7 @@ namespace Slang auto dstAddr2 = builder->emitFieldAddress( uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]); builder->emitStore(dstAddr2, highBits); - fieldOffset += 2; + advanceOffset(8); } } }; @@ -302,6 +371,13 @@ namespace Slang builder.emitStore(concreteTypedVar, param); auto resultVar = builder.emitVar(anyValInfo->type); + // Initialize fields to 0 to prevent downstream compiler error. 
+ for (uint32_t offset = 0; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++) + { + auto fieldAddr = builder.emitFieldAddress(builder.getUIntType(), resultVar, anyValInfo->fieldKeys[offset]); + builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0)); + } + TypePackingContext context; context.anyValInfo = anyValInfo; context.fieldOffset = context.intraFieldOffset = 0; @@ -311,13 +387,6 @@ namespace Slang context.validateAnyTypeSize(sharedContext->sink, type); - // Initialize the rest of unused fields to 0 to prevent downstream compiler error. - for (uint32_t offset = context.fieldOffset; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++) - { - auto fieldAddr = builder.emitFieldAddress(builder.getUIntType(), resultVar, context.anyValInfo->fieldKeys[offset]); - builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0)); - } - auto load = builder.emitLoad(resultVar); builder.emitReturn(load); return func; @@ -332,6 +401,7 @@ namespace Slang case kIROp_IntType: case kIROp_FloatType: { + ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast(anyValInfo->fieldKeys.getCount())) { auto srcAddr = builder->emitFieldAddress( @@ -342,11 +412,12 @@ namespace Slang srcVal = builder->emitBitCast(dataType, srcVal); builder->emitStore(concreteVar, srcVal); } - fieldOffset++; + advanceOffset(4); break; } case kIROp_UIntType: { + ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast(anyValInfo->fieldKeys.getCount())) { auto srcAddr = builder->emitFieldAddress( @@ -356,17 +427,55 @@ namespace Slang auto srcVal = builder->emitLoad(srcAddr); builder->emitStore(concreteVar, srcVal); } - fieldOffset++; + advanceOffset(4); + break; + } + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: + { + ensureOffsetAt2ByteBoundary(); + if (fieldOffset < static_cast(anyValInfo->fieldKeys.getCount())) + { + auto srcAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + 
anyValInfo->fieldKeys[fieldOffset]); + auto srcVal = builder->emitLoad(srcAddr); + if (intraFieldOffset == 0) + { + srcVal = builder->emitBitAnd( + srcVal->getFullType(), srcVal, + builder->getIntValue(builder->getUIntType(), 0xFFFF)); + } + else + { + srcVal = builder->emitShr( + srcVal->getFullType(), srcVal, + builder->getIntValue(builder->getUIntType(), 16)); + } + if (dataType->getOp() == kIROp_Int16Type) + { + srcVal = builder->emitConstructorInst(builder->getType(kIROp_Int16Type), 1, &srcVal); + } + else + { + srcVal = builder->emitConstructorInst(builder->getType(kIROp_UInt16Type), 1, &srcVal); + } + if (dataType->getOp() == kIROp_HalfType) + { + srcVal = builder->emitBitCast(dataType, srcVal); + } + builder->emitStore(concreteVar, srcVal); + } + advanceOffset(2); break; } case kIROp_UInt64Type: case kIROp_Int64Type: case kIROp_DoubleType: case kIROp_Int8Type: - case kIROp_Int16Type: case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_HalfType: SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements"); break; default: @@ -377,6 +486,7 @@ namespace Slang virtual void marshalResourceHandle( IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override { + ensureOffsetAt4ByteBoundary(); if (fieldOffset + 1 < static_cast(anyValInfo->fieldKeys.getCount())) { auto srcAddr = builder->emitFieldAddress( @@ -390,7 +500,7 @@ namespace Slang auto combinedBits = builder->emitMakeUInt64(lowBits, highBits); combinedBits = builder->emitBitCast(dataType, combinedBits); builder->emitStore(concreteVar, combinedBits); - fieldOffset += 2; + advanceOffset(8); } } }; @@ -541,4 +651,97 @@ namespace Slang context.sharedContext = sharedContext; context.processModule(); } + + SlangInt alignUp(SlangInt x, SlangInt alignment) + { + return (x + alignment - 1) / alignment * alignment; + } + + SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) + { + switch (type->getOp()) + { + case kIROp_IntType: + case kIROp_FloatType: + case kIROp_UIntType: + return 
alignUp(offset, 4) + 4; + case kIROp_UInt64Type: + case kIROp_Int64Type: + case kIROp_DoubleType: + return -1; + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: + return alignUp(offset, 2) + 2; + case kIROp_UInt8Type: + case kIROp_Int8Type: + return -1; + case kIROp_VectorType: + { + auto vectorType = static_cast(type); + auto elementType = vectorType->getElementType(); + auto elementCount = getIntVal(vectorType->getElementCount()); + for (IRIntegerValue i = 0; i < elementCount; i++) + { + offset = _getAnyValueSizeRaw(elementType, offset); + if (offset < 0) return offset; + } + return offset; + } + case kIROp_MatrixType: + { + auto matrixType = static_cast(type); + auto elementType = matrixType->getElementType(); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto rowCount = getIntVal(matrixType->getRowCount()); + for (IRIntegerValue i = 0; i < colCount; i++) + { + for (IRIntegerValue j = 0; j < rowCount; j++) + { + offset = _getAnyValueSizeRaw(elementType, offset); + if (offset < 0) return offset; + } + } + return offset; + } + case kIROp_StructType: + { + auto structType = cast(type); + for (auto field : structType->getFields()) + { + offset = _getAnyValueSizeRaw(field->getFieldType(), offset); + if (offset < 0) return offset; + } + return offset; + } + case kIROp_ArrayType: + { + auto arrayType = cast(type); + for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++) + { + offset = _getAnyValueSizeRaw(arrayType->getElementType(), offset); + if (offset < 0) return offset; + } + return offset; + } + case kIROp_InterfaceType: + { + // TODO: implement anyValue packing for interface types. 
+ return -1; + } + default: + if (as(type) || as(type)) + { + return alignUp(offset, 4) + 8; + } + return -1; + } + } + + SlangInt getAnyValueSize(IRType* type) + { + auto rawSize = _getAnyValueSizeRaw(type, 0); + if (rawSize < 0) return rawSize; + return alignUp(rawSize, 4); + } } diff --git a/source/slang/slang-ir-any-value-marshalling.h b/source/slang/slang-ir-any-value-marshalling.h index 943b61aa3..3279eeaac 100644 --- a/source/slang/slang-ir-any-value-marshalling.h +++ b/source/slang/slang-ir-any-value-marshalling.h @@ -1,8 +1,11 @@ // slang-ir-any-value-marshalling.h #pragma once +#include "../core/slang-common.h" + namespace Slang { + struct IRType; struct SharedGenericsLoweringContext; /// Generates functions that pack and unpack `AnyValue`s, and replaces @@ -11,4 +14,8 @@ namespace Slang /// This is a sub-pass of lower-generics. void generateAnyValueMarshallingFunctions( SharedGenericsLoweringContext* sharedContext); + + + /// Get the AnyValue size required to hold a value of `type`. 
+ SlangInt getAnyValueSize(IRType* type); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 797d9f69c..8e8a243a0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -668,6 +668,7 @@ INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) INST(BitCast, bitCast, 1, 0) +INST(Reinterpret, reinterpret, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) @@ -708,4 +709,3 @@ INST_RANGE(Attr, PendingLayoutAttr, VarOffsetAttr) #undef USE_OTHER #undef INST_RANGE #undef INST - diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 0a1f3be8d..049801db1 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2446,6 +2446,7 @@ struct IRBuilder IRInst* emitWaveMaskMatch(IRType* type, IRInst* mask, IRInst* value); IRInst* emitBitAnd(IRType* type, IRInst* left, IRInst* right); + IRInst* emitBitOr(IRType* type, IRInst* left, IRInst* right); IRInst* emitBitNot(IRType* type, IRInst* value); IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right); diff --git a/source/slang/slang-ir-lower-reinterpret.cpp b/source/slang/slang-ir-lower-reinterpret.cpp new file mode 100644 index 000000000..a140bacad --- /dev/null +++ b/source/slang/slang-ir-lower-reinterpret.cpp @@ -0,0 +1,109 @@ +#include "slang-ir-lower-reinterpret.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-layout.h" +#include "slang-ir-any-value-marshalling.h" + +namespace Slang +{ + +struct ReinterpretLoweringContext +{ + TargetRequest* targetReq; + DiagnosticSink* sink; + IRModule* module; + SharedIRBuilder sharedBuilderStorage; + OrderedHashSet workList; + + void addToWorkList(IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as(ii)) + return; + } + 
+ if (workList.Contains(inst)) + return; + + workList.Add(inst); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Reinterpret: + processReinterpret(inst); + break; + default: + break; + } + } + + void processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = module; + sharedBuilder->session = module->session; + + // Deduplicate equivalent types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + + addToWorkList(module->getModuleInst()); + + while (workList.Count() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + + void processReinterpret(IRInst* inst) + { + auto operand = inst->getOperand(0); + auto fromType = operand->getDataType(); + auto toType = inst->getDataType(); + SlangInt fromTypeSize = getAnyValueSize(fromType); + if (fromTypeSize < 0) + { + sink->diagnose(inst->sourceLoc, Slang::Diagnostics::typeCannotBePackedIntoAnyValue, fromType); + } + SlangInt toTypeSize = getAnyValueSize(toType); + if (toTypeSize < 0) + { + sink->diagnose(inst->sourceLoc, Slang::Diagnostics::typeCannotBePackedIntoAnyValue, toType); + } + SlangInt anyValueSize = Math::Max(fromTypeSize, toTypeSize); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + auto anyValueType = builder.getAnyValueType(builder.getIntValue(builder.getUIntType(), anyValueSize)); + auto packInst = builder.emitPackAnyValue( + anyValueType, + operand); + auto unpackInst = builder.emitUnpackAnyValue(toType, packInst); + inst->replaceUsesWith(unpackInst); + inst->removeAndDeallocate(); + } +}; + +void lowerReinterpret(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink) +{ + ReinterpretLoweringContext context; + context.module = module; + context.targetReq = targetReq; + context.sink = 
+ sink; + context.processModule(); +} + +} diff --git a/source/slang/slang-ir-lower-reinterpret.h b/source/slang/slang-ir-lower-reinterpret.h new file mode 100644 index 000000000..623ccb32e --- /dev/null +++ b/source/slang/slang-ir-lower-reinterpret.h @@ -0,0 +1,16 @@ +// slang-ir-lower-reinterpret.h +#pragma once + +// This file defines an IR pass that lowers a reinterpret<T>(U) operation, where T and U are any ordinary data types, +// into a packAnyValue followed by an unpackAnyValue operation. + +namespace Slang +{ + +struct IRModule; +class TargetRequest; +class DiagnosticSink; + +void lowerReinterpret(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 60aaafa83..601ebca26 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4100,6 +4100,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBitOr(IRType* type, IRInst* left, IRInst* right) + { + auto inst = createInst( + this, + kIROp_BitOr, + type, + left, + right); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBitNot(IRType* type, IRInst* value) { auto inst = createInst( -- cgit v1.2.3