diff options
| author | Yong He <yonghe@outlook.com> | 2021-02-12 12:20:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-02-12 12:20:17 -0800 |
| commit | a2401a6ae6c50aeb6ffc196144569bb5253cdf95 (patch) | |
| tree | b69f68b0d5f81ab2d782bfa3ad125637c8f39d96 | |
| parent | 369279e91dde1b056d8d0e3bb83e7ba3f96321af (diff) | |
Support `bit_cast` between complex types. (#1702)
* Support `bit_cast` between complex types.
* Fix vs project file
* Fix clang build error
* fix
* fix
* Fix
* FIx
* Fix
* Fix
* Fix
* Fix
* Fix linux compile error
Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 4 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 12 | ||||
| -rw-r--r-- | source/slang/core.meta.slang | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-extract-value-from-type.cpp | 278 | ||||
| -rw-r--r-- | source/slang/slang-ir-extract-value-from-type.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-layout.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-bit-cast.cpp | 259 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-bit-cast.h | 15 | ||||
| -rw-r--r-- | tests/language-feature/bit-cast/struct-bit-cast.slang | 68 | ||||
| -rw-r--r-- | tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt | 4 |
11 files changed, 692 insertions, 0 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 55233f013..6a8798d59 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -233,6 +233,7 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-entry-point-uniforms.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-explicit-global-context.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-explicit-global-init.h" /> + <ClInclude Include="..\..\..\source\slang\slang-ir-extract-value-from-type.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-generics-lowering-context.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-glsl-legalize.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-hoist-local-types.h" /> @@ -242,6 +243,7 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-layout.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-legalize-varying-params.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-link.h" /> + <ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-call.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-function.h" /> @@ -361,6 +363,7 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-entry-point-uniforms.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-explicit-global-context.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-explicit-global-init.cpp" /> + <ClCompile Include="..\..\..\source\slang\slang-ir-extract-value-from-type.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-generics-lowering-context.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-glsl-legalize.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-hoist-local-types.cpp" /> @@ -369,6 +372,7 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-legalize-types.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-legalize-varying-params.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp" /> + <ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-call.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-function.cpp" /> diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index a31688e05..a6fa6ea0c 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -150,6 +150,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-explicit-global-init.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="..\..\..\source\slang\slang-ir-extract-value-from-type.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="..\..\..\source\slang\slang-ir-generics-lowering-context.h"> <Filter>Header Files</Filter> </ClInclude> @@ -177,6 +180,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-link.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h"> <Filter>Header Files</Filter> </ClInclude> @@ -530,6 +536,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-explicit-global-init.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="..\..\..\source\slang\slang-ir-extract-value-from-type.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="..\..\..\source\slang\slang-ir-generics-lowering-context.cpp"> <Filter>Source Files</Filter> </ClCompile> @@ -554,6 +563,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp"> <Filter>Source Files</Filter> </ClCompile> diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 55f6c607b..887852cbc 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1913,6 +1913,11 @@ ${{{{ }}}} +// Bit cast +__generic<T, U> +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_BitCast)) +T bit_cast(U value); // Specialized function diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 20e8c0beb..c9de26217 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -18,6 +18,7 @@ #include "slang-ir-link.h" #include "slang-ir-lower-generics.h" #include "slang-ir-lower-tuple-types.h" +#include "slang-ir-lower-bit-cast.h" #include "slang-ir-restructure.h" #include "slang-ir-restructure-scoping.h" #include "slang-ir-specialize.h" @@ -682,6 +683,12 @@ Result linkAndOptimizeIR( #endif validateIRModuleIfEnabled(compileRequest, irModule); + // Lower all bit_cast operations on complex types into leaf-level + // bit_cast on basic types. + lowerBitCast(targetRequest, irModule); + eliminateDeadCode(irModule); + validateIRModuleIfEnabled(compileRequest, irModule); + return SLANG_OK; } diff --git a/source/slang/slang-ir-extract-value-from-type.cpp b/source/slang/slang-ir-extract-value-from-type.cpp new file mode 100644 index 000000000..3019b5fd4 --- /dev/null +++ b/source/slang/slang-ir-extract-value-from-type.cpp @@ -0,0 +1,278 @@ +#include "slang-ir-extract-value-from-type.h" +#include "slang-ir-layout.h" +#include "slang-ir-insts.h" +#define CHECK(x) SLANG_RELEASE_ASSERT((x) == SLANG_OK) + +namespace Slang +{ + +// Represents the result of finding the leaf-level value in a type that contains the +// the entirety or the first half of the requested value at the specified offset. +struct FindLeafValueResult +{ + IRInst* leafValue = nullptr; // The leaf-level value. + uint32_t valueSize = 0; // The size of the leaf-level value. + uint32_t offsetInValue = 0; // The offset in bytes within `leafValue` that contains the requested value. +}; + +FindLeafValueResult findLeafValueAtOffset( + TargetRequest* targetReq, + IRBuilder& builder, + IRType* dataType, + IRSizeAndAlignment& layout, + IRInst* src, + uint32_t offset) +{ + FindLeafValueResult result; + if (offset >= layout.size && offset < layout.getStride()) + { + // We are extracting bits beyond the type size but within the stride boundary, + // return a 0 value in this case. + result.leafValue = builder.getIntValue(builder.getUIntType(), 0); + result.valueSize = 4; + result.offsetInValue = (uint32_t)(offset - layout.size); + return result; + } + switch (dataType->op) + { + case kIROp_StructType: + { + auto structType = as<IRStructType>(dataType); + for (auto field : structType->getFields()) + { + IRIntegerValue fieldOffset = 0; + IRSizeAndAlignment fieldLayout; + CHECK(getNaturalSizeAndAlignment(targetReq, field->getFieldType(), &fieldLayout)); + CHECK(getNaturalOffset(targetReq, field, &fieldOffset)); + if (fieldOffset + fieldLayout.size > offset) + { + if (fieldOffset > offset) + { + // This field is starting after the requested offset, + // therefore the requested value is located at the "gap" + // between aligned fields, in this case the requested value + // is 0. + result.leafValue = builder.getIntValue(builder.getUIntType(), 0); + result.valueSize = 4; + result.offsetInValue = (uint32_t)(fieldOffset - offset); + return result; + } + // The field contains requested value. We want to recursively + // traverse the field type to reach a leaf case. + auto fieldValue = + builder.emitFieldExtract(field->getFieldType(), src, field->getKey()); + return findLeafValueAtOffset( + targetReq, + builder, + field->getFieldType(), + fieldLayout, + fieldValue, + (uint32_t)(offset - fieldOffset)); + } + } + result.leafValue = builder.getIntValue(builder.getUIntType(), 0); + result.valueSize = 4; + result.offsetInValue = (uint32_t)(offset - layout.size); + return result; + } + break; + case kIROp_ArrayType: + { + auto arrayType = as<IRArrayType>(dataType); + auto elementType = arrayType->getElementType(); + IRSizeAndAlignment elementLayout; + CHECK(getNaturalSizeAndAlignment(targetReq, elementType, &elementLayout)); + if (elementLayout.getStride() == 0) + { + result.leafValue = builder.getIntValue(builder.getUIntType(), 0); + result.valueSize = 4; + result.offsetInValue = 0; + return result; + } + uint32_t index = offset / (uint32_t)elementLayout.getStride(); + auto elementValue = builder.emitElementExtract( + elementType, src, builder.getIntValue(builder.getUIntType(), index)); + return findLeafValueAtOffset( + targetReq, + builder, + elementType, + elementLayout, + elementValue, + (uint32_t)(offset - elementLayout.getStride() * index)); + } + break; + case kIROp_VectorType: + { + auto vectorType = as<IRVectorType>(dataType); + auto elementType = vectorType->getElementType(); + IRSizeAndAlignment elementLayout; + CHECK(getNaturalSizeAndAlignment(targetReq, elementType, &elementLayout)); + uint32_t index = + elementLayout.getStride() == 0 ? 0 : (uint32_t)(offset / elementLayout.getStride()); + auto elementValue = builder.emitElementExtract( + elementType, src, builder.getIntValue(builder.getUIntType(), index)); + return findLeafValueAtOffset( + targetReq, + builder, + elementType, + elementLayout, + elementValue, + (uint32_t)(offset - elementLayout.getStride() * index)); + } + break; + case kIROp_MatrixType: + { + // Note: this code is assuming row major odering. + auto matrixType = as<IRMatrixType>(dataType); + auto elementType = matrixType->getElementType(); + SLANG_RELEASE_ASSERT(matrixType->getColumnCount()->op == kIROp_IntLit); + auto columnCount = as<IRIntLit>(matrixType->getColumnCount())->value.intVal; + auto rowType = builder.getVectorType(elementType, matrixType->getColumnCount()); + IRSizeAndAlignment rowLayout; + CHECK(getNaturalSizeAndAlignment(targetReq, rowType, &rowLayout)); + uint32_t rowIndex = rowLayout.getStride() == 0 + ? 0 + : (uint32_t)(offset / (columnCount * rowLayout.getStride())); + auto rowValue = builder.emitElementExtract( + rowType, src, builder.getIntValue(builder.getUIntType(), rowIndex)); + return findLeafValueAtOffset( + targetReq, + builder, + rowType, + rowLayout, + rowValue, + (uint32_t)(offset - rowLayout.getStride() * rowIndex)); + } + break; + default: + { + result.leafValue = src; + result.offsetInValue = offset; + result.valueSize = (uint32_t)layout.size; + return result; + } + break; + } +} + +IRInst* extractByteAtOffset( + IRBuilder& builder, + TargetRequest* targetReq, + IRType* dataType, + IRSizeAndAlignment& layout, + IRInst* src, + uint32_t offset) +{ + auto leaf = findLeafValueAtOffset(targetReq, builder, dataType, layout, src, offset); + IRType* uintType = nullptr; + if (leaf.valueSize <= 4) + { + uintType = builder.getUIntType(); + } + else + { + uintType = builder.getUInt64Type(); + } + auto resultValue = builder.emitBitCast(uintType, leaf.leafValue); + if (leaf.offsetInValue != 0) + { + uint32_t shift = leaf.offsetInValue * 8; + resultValue = builder.emitShr(uintType, resultValue, builder.getIntValue(uintType, shift)); + + resultValue = builder.emitBitAnd( + builder.getUIntType(), + resultValue, builder.getIntValue(builder.getUIntType(), 0xFF)); + } + return resultValue; +} + +IRInst* extractMultiByteValueAtOffset( + IRBuilder& builder, + TargetRequest* targetReq, + IRType* dataType, + IRSizeAndAlignment& layout, + IRInst* src, + uint32_t size, + uint32_t offset) +{ + if (size == 1) + return extractByteAtOffset(builder, targetReq, dataType, layout, src, offset); + + auto leaf = findLeafValueAtOffset(targetReq, builder, dataType, layout, src, offset); + auto resultValue = leaf.leafValue; + IRType* uintType = nullptr; + if (leaf.valueSize <= 4) + { + uintType = builder.getUIntType(); + } + else + { + uintType = builder.getUInt64Type(); + } + if (leaf.valueSize - leaf.offsetInValue >= size) + { + // The request value is fully contained in the found leaf element. + // We can proceed to extract the requested bits from the element. + uint32_t shift = leaf.offsetInValue * 8; + if (shift > 0) + resultValue = builder.emitShr(uintType, resultValue, builder.getIntValue(uintType, shift)); + uint32_t bitMask = 0; + switch (size) + { + case 1: + bitMask = 0xFF; + break; + case 2: + bitMask = 0xFFFFF; + break; + case 3: + bitMask = 0xFFFFFF; + break; + case 4: + bitMask = 0xFFFFFFFF; + break; + default: + break; + } + if (leaf.valueSize != size) + { + resultValue = + builder.emitBitAnd(uintType, resultValue, builder.getIntValue(uintType, bitMask)); + } + return resultValue; + } + else + { + // The requested value crosses the boundaries of different fields. + // We need to extract first and second half separately, and combine them together. + auto firstHalf = extractMultiByteValueAtOffset( + builder, targetReq, dataType, layout, src, size / 2, offset); + auto secondHalf = extractMultiByteValueAtOffset( + builder, targetReq, dataType, layout, src, size / 2, offset + size / 2); + uint32_t shift = (size / 2) * 8; + resultValue = builder.emitAdd( + builder.getUIntType(), + firstHalf, + builder.emitShl( + builder.getUIntType(), + secondHalf, + builder.getIntValue(builder.getUIntType(), shift))); + return resultValue; + } +} + +IRInst* extractValueAtOffset( + IRBuilder& builder, TargetRequest* targetReq, IRInst* src, uint32_t offset, uint32_t size) +{ + auto dataType = src->getDataType(); + IRSizeAndAlignment typeLayout; + SLANG_RETURN_NULL_ON_FAIL(getNaturalSizeAndAlignment(targetReq, dataType, &typeLayout)); + if (offset + size > typeLayout.size) + { + return builder.getIntValue(builder.getIntType(), 0); + } + return extractMultiByteValueAtOffset( + builder, targetReq, dataType, typeLayout, src, size, offset); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-extract-value-from-type.h b/source/slang/slang-ir-extract-value-from-type.h new file mode 100644 index 000000000..de39ee545 --- /dev/null +++ b/source/slang/slang-ir-extract-value-from-type.h @@ -0,0 +1,16 @@ +// slang-ir-extract-value-from-type.h +#pragma once + +#include "slang-ir.h" +#include "slang-type-layout.h" + +namespace Slang +{ + +// Emit code using builder that yields an `IRInst` representing a value of `size` bytes +// starting at `offset` in `src`. `src` must be a value of `struct`, array, vector or basic type. +// `size` can be either 1, 2 or 4. The resulting `IRInst` value will have an `uint` type. +IRInst* extractValueAtOffset( + IRBuilder& builder, TargetRequest* targetReq, IRInst* src, uint32_t offset, uint32_t size); + +} diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 41b004372..ed3fff2c0 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -77,6 +77,12 @@ static Result _calcNaturalArraySizeAndAlignment( return SLANG_OK; } +IRIntegerValue getIntegerValueFromInst(IRInst* inst) +{ + SLANG_ASSERT(inst->op == kIROp_IntLit); + return as<IRIntLit>(inst)->value.intVal; +} + static Result _calcNaturalSizeAndAlignment( TargetRequest* target, IRType* type, @@ -192,6 +198,24 @@ static Result _calcNaturalSizeAndAlignment( } break; + case kIROp_MatrixType: + { + auto matType = cast<IRMatrixType>(type); + auto rowCount = getIntegerValueFromInst(matType->getRowCount()); + auto colCount = getIntegerValueFromInst(matType->getColumnCount()); + SharedIRBuilder sharedBuilder; + sharedBuilder.module = type->getModule(); + sharedBuilder.session = sharedBuilder.module->getSession(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilder; + + return _calcNaturalArraySizeAndAlignment( + target, matType->getElementType(), + builder.getIntValue(builder.getUIntType(), rowCount * colCount), + outSizeAndAlignment); + } + break; default: break; } diff --git a/source/slang/slang-ir-lower-bit-cast.cpp b/source/slang/slang-ir-lower-bit-cast.cpp new file mode 100644 index 000000000..ed312741e --- /dev/null +++ b/source/slang/slang-ir-lower-bit-cast.cpp @@ -0,0 +1,259 @@ +#include "slang-ir-lower-bit-cast.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-extract-value-from-type.h" +#include "slang-ir-layout.h" + +namespace Slang +{ + +struct BitCastLoweringContext +{ + TargetRequest* targetReq; + IRModule* module; + SharedIRBuilder sharedBuilderStorage; + OrderedHashSet<IRInst*> workList; + + void addToWorkList(IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as<IRGeneric>(ii)) + return; + } + + if (workList.Contains(inst)) + return; + + workList.Add(inst); + } + + void processInst(IRInst* inst) + { + switch (inst->op) + { + case kIROp_BitCast: + processBitCast(inst); + break; + default: + break; + } + } + + void processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = module; + sharedBuilder->session = module->session; + + // Deduplicate equivalent types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + + addToWorkList(module->getModuleInst()); + + while (workList.Count() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + + + // Extract an object of `type` from `offset` in `src`. + IRInst* readObject(IRBuilder& builder, IRInst* src, IRType* type, uint32_t offset) + { + switch (type->op) + { + case kIROp_StructType: + { + auto structType = as<IRStructType>(type); + List<IRInst*> fieldValues; + for (auto field : structType->getFields()) + { + IRIntegerValue fieldOffset = 0; + SLANG_RELEASE_ASSERT( + getNaturalOffset(targetReq, field, &fieldOffset) == SLANG_OK); + auto fieldType = field->getFieldType(); + auto fieldValue = + readObject(builder, src, fieldType, (uint32_t)(fieldOffset + offset)); + fieldValues.add(fieldValue); + } + return builder.emitMakeStruct(structType, fieldValues); + } + break; + case kIROp_ArrayType: + { + auto arrayType = as<IRArrayType>(type); + auto arrayCount = as<IRIntLit>(arrayType->getElementCount()); + SLANG_RELEASE_ASSERT(arrayCount && "bit_cast: array size must be fixed."); + List<IRInst*> elements; + IRSizeAndAlignment elementLayout; + SLANG_RELEASE_ASSERT( + getNaturalSizeAndAlignment( + targetReq, arrayType->getElementType(), &elementLayout) == SLANG_OK); + for (IRIntegerValue i = 0; i < arrayCount->value.intVal; i++) + { + elements.add(readObject( + builder, + src, + arrayType->getElementType(), + (uint32_t)(offset + elementLayout.getStride() * i))); + } + return builder.emitMakeArray(arrayType, (UInt)arrayCount->value.intVal, elements.getBuffer()); + } + break; + case kIROp_VectorType: + { + auto vectorType = as<IRVectorType>(type); + auto elementCount = as<IRIntLit>(vectorType->getElementCount()); + SLANG_RELEASE_ASSERT(elementCount && "bit_cast: vector size must be int literal."); + List<IRInst*> elements; + IRSizeAndAlignment elementLayout; + SLANG_RELEASE_ASSERT( + getNaturalSizeAndAlignment( + targetReq, vectorType->getElementType(), &elementLayout) == SLANG_OK); + for (IRIntegerValue i = 0; i < elementCount->value.intVal; i++) + { + elements.add(readObject( + builder, + src, + vectorType->getElementType(), + (uint32_t)(offset + elementLayout.getStride() * i))); + } + return builder.emitMakeVector( + vectorType, (UInt)elementCount->value.intVal, elements.getBuffer()); + } + break; + case kIROp_MatrixType: + { + // Assuming row-major order + auto matrixType = as<IRMatrixType>(type); + auto elementCount = as<IRIntLit>(matrixType->getRowCount()); + SLANG_RELEASE_ASSERT( + elementCount && "bit_cast: vector size must be int literal."); + List<IRInst*> elements; + auto elementType = builder.getVectorType( + matrixType->getElementType(), matrixType->getColumnCount()); + IRSizeAndAlignment elementLayout; + SLANG_RELEASE_ASSERT( + getNaturalSizeAndAlignment(targetReq, elementType, &elementLayout) == SLANG_OK); + for (IRIntegerValue i = 0; i < elementCount->value.intVal; i++) + { + elements.add(readObject( + builder, + src, + elementType, + (uint32_t)(offset + elementLayout.getStride() * i))); + } + return builder.emitMakeMatrix( + matrixType, (UInt)elementCount->value.intVal, elements.getBuffer()); + } + break; + case kIROp_HalfType: + case kIROp_Int16Type: + case kIROp_UInt16Type: + { + auto object = extractValueAtOffset(builder, targetReq, src, offset, 2); + return builder.emitBitCast(type, object); + } + break; + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_FloatType: + case kIROp_BoolType: + { + auto object = extractValueAtOffset(builder, targetReq, src, offset, 4); + return builder.emitBitCast(type, object); + } + break; + case kIROp_DoubleType: + case kIROp_Int64Type: + case kIROp_UInt64Type: + case kIROp_RawPointerType: + { + auto low = extractValueAtOffset(builder, targetReq, src, offset, 4); + auto high = extractValueAtOffset(builder, targetReq, src, offset + 4, 4); + auto combined = builder.emitAdd(builder.getUInt64Type(), + low, + builder.emitShl( + builder.getUInt64Type(), + high, + builder.getIntValue(builder.getUIntType(), 32))); + if (type->op == kIROp_UInt64Type) + return combined; + return builder.emitBitCast(type, combined); + } + break; + case kIROp_UInt8Type: + case kIROp_Int8Type: + { + auto object = extractValueAtOffset(builder, targetReq, src, offset, 1); + return builder.emitBitCast(type, object); + } + break; + default: + { + SLANG_UNEXPECTED("Unable to generate bit_cast code for the given type"); + } + break; + } + } + + void processBitCast(IRInst* inst) + { + auto operand = inst->getOperand(0); + auto fromType = operand->getDataType(); + auto toType = inst->getDataType(); + if (as<IRBasicType>(fromType) != nullptr && as<IRBasicType>(toType) != nullptr) + { + // Both fromType and toType are basic types, no processing needed. + return; + } + // Ignore cases we cannot handle yet. + if (as<IRPtrType>(fromType) || as<IRPtrType>(toType)) + { + return; + } + if (as<IRRawPointerType>(fromType) || as<IRRawPointerType>(toType)) + { + return; + } + if (as<IRResourceTypeBase>(fromType) || as<IRResourceTypeBase>(toType)) + { + return; + } + if (as<IRPointerLikeType>(fromType) || as<IRPointerLikeType>(toType)) + { + return; + } + if (as<IRSamplerStateTypeBase>(fromType) || as<IRSamplerStateTypeBase>(toType)) + { + return; + } + // Enumerate all fields in to-type and obtain its value from operand object. + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + auto finalObject = readObject(builder, operand, toType, 0); + inst->replaceUsesWith(finalObject); + inst->removeAndDeallocate(); + } +}; + +void lowerBitCast(TargetRequest* targetReq, IRModule* module) +{ + BitCastLoweringContext context; + context.module = module; + context.targetReq = targetReq; + context.processModule(); +} + +} diff --git a/source/slang/slang-ir-lower-bit-cast.h b/source/slang/slang-ir-lower-bit-cast.h new file mode 100644 index 000000000..4e93dbe4f --- /dev/null +++ b/source/slang/slang-ir-lower-bit-cast.h @@ -0,0 +1,15 @@ +// slang-ir-lower-bit-cast.h +#pragma once + +// This file defines an IR pass that lowers a BitCast<T>(U) operation, where T and U are struct types, +// into a series of bit-cast operations on basic-typed elements. + +namespace Slang +{ + +struct IRModule; +class TargetRequest; + +void lowerBitCast(TargetRequest* targetReq, IRModule* module); + +} diff --git a/tests/language-feature/bit-cast/struct-bit-cast.slang b/tests/language-feature/bit-cast/struct-bit-cast.slang new file mode 100644 index 000000000..9c4a039c0 --- /dev/null +++ b/tests/language-feature/bit-cast/struct-bit-cast.slang @@ -0,0 +1,68 @@ +// struct-bit-cast.slang + +//TEST(compute):COMPARE_COMPUTE: -shaderobj + +// Test that bit_cast works for bit-reinterpreting one struct type as another. + +struct Foo +{ + uint a; + float b; + float2 fvec; +} + +struct Inner +{ + int v; + uint s; +} + +struct Bar +{ + int u; + Inner i; + uint t; +} + +int test0(int val) +{ + Bar b; + b.u = val; + b.i.v = asint(2.0f); + b.i.s = asuint(1.25); + b.t = asuint(0.25); + Foo f = bit_cast<Foo, Bar>(b); + return f.a + (int)f.b + int(float(f.fvec.x / f.fvec.y)); // val + 2 + 5 +} + +struct Smaller +{ + int s; +} + +struct Larger +{ + int x, y; +} + +int test1() +{ + Smaller s = {1}; + int v0 = bit_cast<Larger, Smaller>(s).y; // 0. + Larger l = {1, 2}; + int v1 = bit_cast<Smaller, Larger>(l).s; // 1. + return v0 + v1; +} + + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test0(inVal) + test1(); + outputBuffer[tid] = outVal; +} diff --git a/tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt b/tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt new file mode 100644 index 000000000..9ace6947d --- /dev/null +++ b/tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt @@ -0,0 +1,4 @@ +8 +9 +A +B |
