summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-02-12 12:20:17 -0800
committerGitHub <noreply@github.com>2021-02-12 12:20:17 -0800
commita2401a6ae6c50aeb6ffc196144569bb5253cdf95 (patch)
treeb69f68b0d5f81ab2d782bfa3ad125637c8f39d96
parent369279e91dde1b056d8d0e3bb83e7ba3f96321af (diff)
Support `bit_cast` between complex types. (#1702)
* Support `bit_cast` between complex types. * Fix vs project file * Fix clang build error * fix * fix * Fix * FIx * Fix * Fix * Fix * Fix * Fix linux compile error Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
-rw-r--r--build/visual-studio/slang/slang.vcxproj4
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters12
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-extract-value-from-type.cpp278
-rw-r--r--source/slang/slang-ir-extract-value-from-type.h16
-rw-r--r--source/slang/slang-ir-layout.cpp24
-rw-r--r--source/slang/slang-ir-lower-bit-cast.cpp259
-rw-r--r--source/slang/slang-ir-lower-bit-cast.h15
-rw-r--r--tests/language-feature/bit-cast/struct-bit-cast.slang68
-rw-r--r--tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt4
11 files changed, 692 insertions, 0 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 55233f013..6a8798d59 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -233,6 +233,7 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-entry-point-uniforms.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-explicit-global-context.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-explicit-global-init.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-extract-value-from-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-generics-lowering-context.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-glsl-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-hoist-local-types.h" />
@@ -242,6 +243,7 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-layout.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-legalize-varying-params.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-link.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-call.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-function.h" />
@@ -361,6 +363,7 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-entry-point-uniforms.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-explicit-global-context.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-explicit-global-init.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-extract-value-from-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-generics-lowering-context.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-glsl-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-hoist-local-types.cpp" />
@@ -369,6 +372,7 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-legalize-types.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-legalize-varying-params.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-call.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-function.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index a31688e05..a6fa6ea0c 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -150,6 +150,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-explicit-global-init.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-extract-value-from-type.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-generics-lowering-context.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -177,6 +180,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-link.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -530,6 +536,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-explicit-global-init.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-extract-value-from-type.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-generics-lowering-context.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -554,6 +563,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 55f6c607b..887852cbc 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1913,6 +1913,11 @@ ${{{{
}}}}
+// Bit cast
+__generic<T, U>
+[__unsafeForceInlineEarly]
+__intrinsic_op($(kIROp_BitCast))
+T bit_cast(U value);
// Specialized function
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 20e8c0beb..c9de26217 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -18,6 +18,7 @@
#include "slang-ir-link.h"
#include "slang-ir-lower-generics.h"
#include "slang-ir-lower-tuple-types.h"
+#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-restructure.h"
#include "slang-ir-restructure-scoping.h"
#include "slang-ir-specialize.h"
@@ -682,6 +683,12 @@ Result linkAndOptimizeIR(
#endif
validateIRModuleIfEnabled(compileRequest, irModule);
+ // Lower all bit_cast operations on complex types into leaf-level
+ // bit_cast on basic types.
+ lowerBitCast(targetRequest, irModule);
+ eliminateDeadCode(irModule);
+ validateIRModuleIfEnabled(compileRequest, irModule);
+
return SLANG_OK;
}
diff --git a/source/slang/slang-ir-extract-value-from-type.cpp b/source/slang/slang-ir-extract-value-from-type.cpp
new file mode 100644
index 000000000..3019b5fd4
--- /dev/null
+++ b/source/slang/slang-ir-extract-value-from-type.cpp
@@ -0,0 +1,278 @@
+#include "slang-ir-extract-value-from-type.h"
+#include "slang-ir-layout.h"
+#include "slang-ir-insts.h"
+#define CHECK(x) SLANG_RELEASE_ASSERT((x) == SLANG_OK)
+
+namespace Slang
+{
+
+// Represents the result of finding the leaf-level value in a type that contains the
+// the entirety or the first half of the requested value at the specified offset.
+struct FindLeafValueResult
+{
+ IRInst* leafValue = nullptr; // The leaf-level value.
+ uint32_t valueSize = 0; // The size of the leaf-level value.
+ uint32_t offsetInValue = 0; // The offset in bytes within `leafValue` that contains the requested value.
+};
+
+FindLeafValueResult findLeafValueAtOffset(
+ TargetRequest* targetReq,
+ IRBuilder& builder,
+ IRType* dataType,
+ IRSizeAndAlignment& layout,
+ IRInst* src,
+ uint32_t offset)
+{
+ FindLeafValueResult result;
+ if (offset >= layout.size && offset < layout.getStride())
+ {
+ // We are extracting bits beyond the type size but within the stride boundary,
+ // return a 0 value in this case.
+ result.leafValue = builder.getIntValue(builder.getUIntType(), 0);
+ result.valueSize = 4;
+ result.offsetInValue = (uint32_t)(offset - layout.size);
+ return result;
+ }
+ switch (dataType->op)
+ {
+ case kIROp_StructType:
+ {
+ auto structType = as<IRStructType>(dataType);
+ for (auto field : structType->getFields())
+ {
+ IRIntegerValue fieldOffset = 0;
+ IRSizeAndAlignment fieldLayout;
+ CHECK(getNaturalSizeAndAlignment(targetReq, field->getFieldType(), &fieldLayout));
+ CHECK(getNaturalOffset(targetReq, field, &fieldOffset));
+ if (fieldOffset + fieldLayout.size > offset)
+ {
+ if (fieldOffset > offset)
+ {
+ // This field is starting after the requested offset,
+ // therefore the requested value is located at the "gap"
+ // between aligned fields, in this case the requested value
+ // is 0.
+ result.leafValue = builder.getIntValue(builder.getUIntType(), 0);
+ result.valueSize = 4;
+ result.offsetInValue = (uint32_t)(fieldOffset - offset);
+ return result;
+ }
+ // The field contains requested value. We want to recursively
+ // traverse the field type to reach a leaf case.
+ auto fieldValue =
+ builder.emitFieldExtract(field->getFieldType(), src, field->getKey());
+ return findLeafValueAtOffset(
+ targetReq,
+ builder,
+ field->getFieldType(),
+ fieldLayout,
+ fieldValue,
+ (uint32_t)(offset - fieldOffset));
+ }
+ }
+ result.leafValue = builder.getIntValue(builder.getUIntType(), 0);
+ result.valueSize = 4;
+ result.offsetInValue = (uint32_t)(offset - layout.size);
+ return result;
+ }
+ break;
+ case kIROp_ArrayType:
+ {
+ auto arrayType = as<IRArrayType>(dataType);
+ auto elementType = arrayType->getElementType();
+ IRSizeAndAlignment elementLayout;
+ CHECK(getNaturalSizeAndAlignment(targetReq, elementType, &elementLayout));
+ if (elementLayout.getStride() == 0)
+ {
+ result.leafValue = builder.getIntValue(builder.getUIntType(), 0);
+ result.valueSize = 4;
+ result.offsetInValue = 0;
+ return result;
+ }
+ uint32_t index = offset / (uint32_t)elementLayout.getStride();
+ auto elementValue = builder.emitElementExtract(
+ elementType, src, builder.getIntValue(builder.getUIntType(), index));
+ return findLeafValueAtOffset(
+ targetReq,
+ builder,
+ elementType,
+ elementLayout,
+ elementValue,
+ (uint32_t)(offset - elementLayout.getStride() * index));
+ }
+ break;
+ case kIROp_VectorType:
+ {
+ auto vectorType = as<IRVectorType>(dataType);
+ auto elementType = vectorType->getElementType();
+ IRSizeAndAlignment elementLayout;
+ CHECK(getNaturalSizeAndAlignment(targetReq, elementType, &elementLayout));
+ uint32_t index =
+ elementLayout.getStride() == 0 ? 0 : (uint32_t)(offset / elementLayout.getStride());
+ auto elementValue = builder.emitElementExtract(
+ elementType, src, builder.getIntValue(builder.getUIntType(), index));
+ return findLeafValueAtOffset(
+ targetReq,
+ builder,
+ elementType,
+ elementLayout,
+ elementValue,
+ (uint32_t)(offset - elementLayout.getStride() * index));
+ }
+ break;
+ case kIROp_MatrixType:
+ {
+ // Note: this code is assuming row major odering.
+ auto matrixType = as<IRMatrixType>(dataType);
+ auto elementType = matrixType->getElementType();
+ SLANG_RELEASE_ASSERT(matrixType->getColumnCount()->op == kIROp_IntLit);
+ auto columnCount = as<IRIntLit>(matrixType->getColumnCount())->value.intVal;
+ auto rowType = builder.getVectorType(elementType, matrixType->getColumnCount());
+ IRSizeAndAlignment rowLayout;
+ CHECK(getNaturalSizeAndAlignment(targetReq, rowType, &rowLayout));
+ uint32_t rowIndex = rowLayout.getStride() == 0
+ ? 0
+ : (uint32_t)(offset / (columnCount * rowLayout.getStride()));
+ auto rowValue = builder.emitElementExtract(
+ rowType, src, builder.getIntValue(builder.getUIntType(), rowIndex));
+ return findLeafValueAtOffset(
+ targetReq,
+ builder,
+ rowType,
+ rowLayout,
+ rowValue,
+ (uint32_t)(offset - rowLayout.getStride() * rowIndex));
+ }
+ break;
+ default:
+ {
+ result.leafValue = src;
+ result.offsetInValue = offset;
+ result.valueSize = (uint32_t)layout.size;
+ return result;
+ }
+ break;
+ }
+}
+
+IRInst* extractByteAtOffset(
+ IRBuilder& builder,
+ TargetRequest* targetReq,
+ IRType* dataType,
+ IRSizeAndAlignment& layout,
+ IRInst* src,
+ uint32_t offset)
+{
+ auto leaf = findLeafValueAtOffset(targetReq, builder, dataType, layout, src, offset);
+ IRType* uintType = nullptr;
+ if (leaf.valueSize <= 4)
+ {
+ uintType = builder.getUIntType();
+ }
+ else
+ {
+ uintType = builder.getUInt64Type();
+ }
+ auto resultValue = builder.emitBitCast(uintType, leaf.leafValue);
+ if (leaf.offsetInValue != 0)
+ {
+ uint32_t shift = leaf.offsetInValue * 8;
+ resultValue = builder.emitShr(uintType, resultValue, builder.getIntValue(uintType, shift));
+
+ resultValue = builder.emitBitAnd(
+ builder.getUIntType(),
+ resultValue, builder.getIntValue(builder.getUIntType(), 0xFF));
+ }
+ return resultValue;
+}
+
+IRInst* extractMultiByteValueAtOffset(
+ IRBuilder& builder,
+ TargetRequest* targetReq,
+ IRType* dataType,
+ IRSizeAndAlignment& layout,
+ IRInst* src,
+ uint32_t size,
+ uint32_t offset)
+{
+ if (size == 1)
+ return extractByteAtOffset(builder, targetReq, dataType, layout, src, offset);
+
+ auto leaf = findLeafValueAtOffset(targetReq, builder, dataType, layout, src, offset);
+ auto resultValue = leaf.leafValue;
+ IRType* uintType = nullptr;
+ if (leaf.valueSize <= 4)
+ {
+ uintType = builder.getUIntType();
+ }
+ else
+ {
+ uintType = builder.getUInt64Type();
+ }
+ if (leaf.valueSize - leaf.offsetInValue >= size)
+ {
+ // The request value is fully contained in the found leaf element.
+ // We can proceed to extract the requested bits from the element.
+ uint32_t shift = leaf.offsetInValue * 8;
+ if (shift > 0)
+ resultValue = builder.emitShr(uintType, resultValue, builder.getIntValue(uintType, shift));
+ uint32_t bitMask = 0;
+ switch (size)
+ {
+ case 1:
+ bitMask = 0xFF;
+ break;
+ case 2:
+ bitMask = 0xFFFFF;
+ break;
+ case 3:
+ bitMask = 0xFFFFFF;
+ break;
+ case 4:
+ bitMask = 0xFFFFFFFF;
+ break;
+ default:
+ break;
+ }
+ if (leaf.valueSize != size)
+ {
+ resultValue =
+ builder.emitBitAnd(uintType, resultValue, builder.getIntValue(uintType, bitMask));
+ }
+ return resultValue;
+ }
+ else
+ {
+ // The requested value crosses the boundaries of different fields.
+ // We need to extract first and second half separately, and combine them together.
+ auto firstHalf = extractMultiByteValueAtOffset(
+ builder, targetReq, dataType, layout, src, size / 2, offset);
+ auto secondHalf = extractMultiByteValueAtOffset(
+ builder, targetReq, dataType, layout, src, size / 2, offset + size / 2);
+ uint32_t shift = (size / 2) * 8;
+ resultValue = builder.emitAdd(
+ builder.getUIntType(),
+ firstHalf,
+ builder.emitShl(
+ builder.getUIntType(),
+ secondHalf,
+ builder.getIntValue(builder.getUIntType(), shift)));
+ return resultValue;
+ }
+}
+
+IRInst* extractValueAtOffset(
+ IRBuilder& builder, TargetRequest* targetReq, IRInst* src, uint32_t offset, uint32_t size)
+{
+ auto dataType = src->getDataType();
+ IRSizeAndAlignment typeLayout;
+ SLANG_RETURN_NULL_ON_FAIL(getNaturalSizeAndAlignment(targetReq, dataType, &typeLayout));
+ if (offset + size > typeLayout.size)
+ {
+ return builder.getIntValue(builder.getIntType(), 0);
+ }
+ return extractMultiByteValueAtOffset(
+ builder, targetReq, dataType, typeLayout, src, size, offset);
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-extract-value-from-type.h b/source/slang/slang-ir-extract-value-from-type.h
new file mode 100644
index 000000000..de39ee545
--- /dev/null
+++ b/source/slang/slang-ir-extract-value-from-type.h
@@ -0,0 +1,16 @@
+// slang-ir-extract-value-from-type.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-type-layout.h"
+
+namespace Slang
+{
+
+// Emit code using builder that yields an `IRInst` representing a value of `size` bytes
+// starting at `offset` in `src`. `src` must be a value of `struct`, array, vector or basic type.
+// `size` can be either 1, 2 or 4. The resulting `IRInst` value will have an `uint` type.
+IRInst* extractValueAtOffset(
+ IRBuilder& builder, TargetRequest* targetReq, IRInst* src, uint32_t offset, uint32_t size);
+
+}
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index 41b004372..ed3fff2c0 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -77,6 +77,12 @@ static Result _calcNaturalArraySizeAndAlignment(
return SLANG_OK;
}
+IRIntegerValue getIntegerValueFromInst(IRInst* inst)
+{
+ SLANG_ASSERT(inst->op == kIROp_IntLit);
+ return as<IRIntLit>(inst)->value.intVal;
+}
+
static Result _calcNaturalSizeAndAlignment(
TargetRequest* target,
IRType* type,
@@ -192,6 +198,24 @@ static Result _calcNaturalSizeAndAlignment(
}
break;
+ case kIROp_MatrixType:
+ {
+ auto matType = cast<IRMatrixType>(type);
+ auto rowCount = getIntegerValueFromInst(matType->getRowCount());
+ auto colCount = getIntegerValueFromInst(matType->getColumnCount());
+ SharedIRBuilder sharedBuilder;
+ sharedBuilder.module = type->getModule();
+ sharedBuilder.session = sharedBuilder.module->getSession();
+
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilder;
+
+ return _calcNaturalArraySizeAndAlignment(
+ target, matType->getElementType(),
+ builder.getIntValue(builder.getUIntType(), rowCount * colCount),
+ outSizeAndAlignment);
+ }
+ break;
default:
break;
}
diff --git a/source/slang/slang-ir-lower-bit-cast.cpp b/source/slang/slang-ir-lower-bit-cast.cpp
new file mode 100644
index 000000000..ed312741e
--- /dev/null
+++ b/source/slang/slang-ir-lower-bit-cast.cpp
@@ -0,0 +1,259 @@
+#include "slang-ir-lower-bit-cast.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-extract-value-from-type.h"
+#include "slang-ir-layout.h"
+
+namespace Slang
+{
+
+struct BitCastLoweringContext
+{
+ TargetRequest* targetReq;
+ IRModule* module;
+ SharedIRBuilder sharedBuilderStorage;
+ OrderedHashSet<IRInst*> workList;
+
+ void addToWorkList(IRInst* inst)
+ {
+ for (auto ii = inst->getParent(); ii; ii = ii->getParent())
+ {
+ if (as<IRGeneric>(ii))
+ return;
+ }
+
+ if (workList.Contains(inst))
+ return;
+
+ workList.Add(inst);
+ }
+
+ void processInst(IRInst* inst)
+ {
+ switch (inst->op)
+ {
+ case kIROp_BitCast:
+ processBitCast(inst);
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processModule()
+ {
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = module;
+ sharedBuilder->session = module->session;
+
+ // Deduplicate equivalent types.
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
+ addToWorkList(module->getModuleInst());
+
+ while (workList.Count() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+ }
+
+
+ // Extract an object of `type` from `offset` in `src`.
+ IRInst* readObject(IRBuilder& builder, IRInst* src, IRType* type, uint32_t offset)
+ {
+ switch (type->op)
+ {
+ case kIROp_StructType:
+ {
+ auto structType = as<IRStructType>(type);
+ List<IRInst*> fieldValues;
+ for (auto field : structType->getFields())
+ {
+ IRIntegerValue fieldOffset = 0;
+ SLANG_RELEASE_ASSERT(
+ getNaturalOffset(targetReq, field, &fieldOffset) == SLANG_OK);
+ auto fieldType = field->getFieldType();
+ auto fieldValue =
+ readObject(builder, src, fieldType, (uint32_t)(fieldOffset + offset));
+ fieldValues.add(fieldValue);
+ }
+ return builder.emitMakeStruct(structType, fieldValues);
+ }
+ break;
+ case kIROp_ArrayType:
+ {
+ auto arrayType = as<IRArrayType>(type);
+ auto arrayCount = as<IRIntLit>(arrayType->getElementCount());
+ SLANG_RELEASE_ASSERT(arrayCount && "bit_cast: array size must be fixed.");
+ List<IRInst*> elements;
+ IRSizeAndAlignment elementLayout;
+ SLANG_RELEASE_ASSERT(
+ getNaturalSizeAndAlignment(
+ targetReq, arrayType->getElementType(), &elementLayout) == SLANG_OK);
+ for (IRIntegerValue i = 0; i < arrayCount->value.intVal; i++)
+ {
+ elements.add(readObject(
+ builder,
+ src,
+ arrayType->getElementType(),
+ (uint32_t)(offset + elementLayout.getStride() * i)));
+ }
+ return builder.emitMakeArray(arrayType, (UInt)arrayCount->value.intVal, elements.getBuffer());
+ }
+ break;
+ case kIROp_VectorType:
+ {
+ auto vectorType = as<IRVectorType>(type);
+ auto elementCount = as<IRIntLit>(vectorType->getElementCount());
+ SLANG_RELEASE_ASSERT(elementCount && "bit_cast: vector size must be int literal.");
+ List<IRInst*> elements;
+ IRSizeAndAlignment elementLayout;
+ SLANG_RELEASE_ASSERT(
+ getNaturalSizeAndAlignment(
+ targetReq, vectorType->getElementType(), &elementLayout) == SLANG_OK);
+ for (IRIntegerValue i = 0; i < elementCount->value.intVal; i++)
+ {
+ elements.add(readObject(
+ builder,
+ src,
+ vectorType->getElementType(),
+ (uint32_t)(offset + elementLayout.getStride() * i)));
+ }
+ return builder.emitMakeVector(
+ vectorType, (UInt)elementCount->value.intVal, elements.getBuffer());
+ }
+ break;
+ case kIROp_MatrixType:
+ {
+ // Assuming row-major order
+ auto matrixType = as<IRMatrixType>(type);
+ auto elementCount = as<IRIntLit>(matrixType->getRowCount());
+ SLANG_RELEASE_ASSERT(
+ elementCount && "bit_cast: vector size must be int literal.");
+ List<IRInst*> elements;
+ auto elementType = builder.getVectorType(
+ matrixType->getElementType(), matrixType->getColumnCount());
+ IRSizeAndAlignment elementLayout;
+ SLANG_RELEASE_ASSERT(
+ getNaturalSizeAndAlignment(targetReq, elementType, &elementLayout) == SLANG_OK);
+ for (IRIntegerValue i = 0; i < elementCount->value.intVal; i++)
+ {
+ elements.add(readObject(
+ builder,
+ src,
+ elementType,
+ (uint32_t)(offset + elementLayout.getStride() * i)));
+ }
+ return builder.emitMakeMatrix(
+ matrixType, (UInt)elementCount->value.intVal, elements.getBuffer());
+ }
+ break;
+ case kIROp_HalfType:
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ {
+ auto object = extractValueAtOffset(builder, targetReq, src, offset, 2);
+ return builder.emitBitCast(type, object);
+ }
+ break;
+ case kIROp_IntType:
+ case kIROp_UIntType:
+ case kIROp_FloatType:
+ case kIROp_BoolType:
+ {
+ auto object = extractValueAtOffset(builder, targetReq, src, offset, 4);
+ return builder.emitBitCast(type, object);
+ }
+ break;
+ case kIROp_DoubleType:
+ case kIROp_Int64Type:
+ case kIROp_UInt64Type:
+ case kIROp_RawPointerType:
+ {
+ auto low = extractValueAtOffset(builder, targetReq, src, offset, 4);
+ auto high = extractValueAtOffset(builder, targetReq, src, offset + 4, 4);
+ auto combined = builder.emitAdd(builder.getUInt64Type(),
+ low,
+ builder.emitShl(
+ builder.getUInt64Type(),
+ high,
+ builder.getIntValue(builder.getUIntType(), 32)));
+ if (type->op == kIROp_UInt64Type)
+ return combined;
+ return builder.emitBitCast(type, combined);
+ }
+ break;
+ case kIROp_UInt8Type:
+ case kIROp_Int8Type:
+ {
+ auto object = extractValueAtOffset(builder, targetReq, src, offset, 1);
+ return builder.emitBitCast(type, object);
+ }
+ break;
+ default:
+ {
+ SLANG_UNEXPECTED("Unable to generate bit_cast code for the given type");
+ }
+ break;
+ }
+ }
+
+ void processBitCast(IRInst* inst)
+ {
+ auto operand = inst->getOperand(0);
+ auto fromType = operand->getDataType();
+ auto toType = inst->getDataType();
+ if (as<IRBasicType>(fromType) != nullptr && as<IRBasicType>(toType) != nullptr)
+ {
+ // Both fromType and toType are basic types, no processing needed.
+ return;
+ }
+ // Ignore cases we cannot handle yet.
+ if (as<IRPtrType>(fromType) || as<IRPtrType>(toType))
+ {
+ return;
+ }
+ if (as<IRRawPointerType>(fromType) || as<IRRawPointerType>(toType))
+ {
+ return;
+ }
+ if (as<IRResourceTypeBase>(fromType) || as<IRResourceTypeBase>(toType))
+ {
+ return;
+ }
+ if (as<IRPointerLikeType>(fromType) || as<IRPointerLikeType>(toType))
+ {
+ return;
+ }
+ if (as<IRSamplerStateTypeBase>(fromType) || as<IRSamplerStateTypeBase>(toType))
+ {
+ return;
+ }
+ // Enumerate all fields in to-type and obtain its value from operand object.
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(inst);
+ auto finalObject = readObject(builder, operand, toType, 0);
+ inst->replaceUsesWith(finalObject);
+ inst->removeAndDeallocate();
+ }
+};
+
+void lowerBitCast(TargetRequest* targetReq, IRModule* module)
+{
+ BitCastLoweringContext context;
+ context.module = module;
+ context.targetReq = targetReq;
+ context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-lower-bit-cast.h b/source/slang/slang-ir-lower-bit-cast.h
new file mode 100644
index 000000000..4e93dbe4f
--- /dev/null
+++ b/source/slang/slang-ir-lower-bit-cast.h
@@ -0,0 +1,15 @@
+// slang-ir-lower-bit-cast.h
+#pragma once
+
+// This file defines an IR pass that lowers a BitCast<T>(U) operation, where T and U are struct types,
+// into a series of bit-cast operations on basic-typed elements.
+
+namespace Slang
+{
+
+struct IRModule;
+class TargetRequest;
+
+void lowerBitCast(TargetRequest* targetReq, IRModule* module);
+
+}
diff --git a/tests/language-feature/bit-cast/struct-bit-cast.slang b/tests/language-feature/bit-cast/struct-bit-cast.slang
new file mode 100644
index 000000000..9c4a039c0
--- /dev/null
+++ b/tests/language-feature/bit-cast/struct-bit-cast.slang
@@ -0,0 +1,68 @@
+// struct-bit-cast.slang
+
+//TEST(compute):COMPARE_COMPUTE: -shaderobj
+
+// Test that bit_cast works for bit-reinterpreting one struct type as another.
+
+struct Foo
+{
+ uint a;
+ float b;
+ float2 fvec;
+}
+
+struct Inner
+{
+ int v;
+ uint s;
+}
+
+struct Bar
+{
+ int u;
+ Inner i;
+ uint t;
+}
+
+int test0(int val)
+{
+ Bar b;
+ b.u = val;
+ b.i.v = asint(2.0f);
+ b.i.s = asuint(1.25);
+ b.t = asuint(0.25);
+ Foo f = bit_cast<Foo, Bar>(b);
+ return f.a + (int)f.b + int(float(f.fvec.x / f.fvec.y)); // val + 2 + 5
+}
+
+struct Smaller
+{
+ int s;
+}
+
+struct Larger
+{
+ int x, y;
+}
+
+int test1()
+{
+ Smaller s = {1};
+ int v0 = bit_cast<Larger, Smaller>(s).y; // 0.
+ Larger l = {1, 2};
+ int v1 = bit_cast<Smaller, Larger>(l).s; // 1.
+ return v0 + v1;
+}
+
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test0(inVal) + test1();
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt b/tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt
new file mode 100644
index 000000000..9ace6947d
--- /dev/null
+++ b/tests/language-feature/bit-cast/struct-bit-cast.slang.expected.txt
@@ -0,0 +1,4 @@
+8
+9
+A
+B