summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-09-29 17:45:08 -0700
committerGitHub <noreply@github.com>2025-09-30 00:45:08 +0000
commita6deb5ed82cb8fc6b4f4c5c5fee264e09f97ff89 (patch)
tree1c374bd52498cad2e142e3c7f5482fd42dca966f
parent2827c94de5901cac42a67f73a78ab2548771b28c (diff)
Rewriting the lower-buffer-element-type pass to avoid unnecessary packing/unpacking. (#8526)
Part of the effort to improve the performance of generated SPIRV code. The existing lower-buffer-element-type pass works by loading the entire buffer element content from memory, and translating it to the logical type stored in a local variable at the earliest reference of a buffer handle. This means that it can generate inefficient code that reads more than necessary. Consider this example: ``` struct BigStruct { bool values[1024]; } ConstantBuffer<BigStruct> cb; void test(BigStruct v) { if (v.values[0]) { printf("ok"); } } [numthreads(1,1,1)] void computeMain() { test(cb); } ``` In IR, the `computeMain` function before lower-buffer-element-type pass is something like the following: ``` func test: %v = param : BigStruct %barr = fieldExtract(%v, "values") %element = elementExtract(%barr, 0) ... // uses %element func computeMain: %v = load(cb) call %test %v ``` The existing lower-buffer-element-type pass will rewrite the bool array in `BigStruct` into an `int` array so it is legal in SPIRV. However, it does so by inserting the translation on the first `load` of the constant buffer: ``` struct BigStruct_std430 { int values[1024]; } var cb : ConstantBuffer<BigStruct_std430>; func computeMain: %tmpVar : var<BigStruct> call %unpackStorage(%tmpVar, cb) %v : BigStruct = load %tmpVar call %test %v ``` This means that the entire array will be loaded and translated to int, before calling `test`, which only uses one element. It turns out that the downstream compiler isn't always able to optimize out this inefficient translation/copy. This PR completely rewrites the way buffer-element-type lowering is handled to avoid producing this inefficient code. It works in two parts: first we turn on the `transformParamsToConstRef` pass for the SPIRV target as well, so we will translate the `test` function to take the `v` parameter as `constref`. 
The second part is a redesigned buffer-element-type pass that defers the storage-type to logical-type translation until a value is actually used by a `load` instruction. In this example, after `transformParamsToConstRef`, the IR is: ``` func test: %v = param : ConstRef<BigStruct> %barr = fieldAddr(%v, "values") %elementPtr = elementAddr(%barr, 0) %element = load(%elementPtr) ... // uses %element func computeMain: call %test %cb ``` The new `buffer-element-type-lowering` pass will take this IR, and insert the translation at the latest possible time across the entire call graph, and translate the IR into: ``` func test: %v = param : ConstRef<BigStruct_std430> %barr = fieldAddr(%v, "values") %elementPtr : ptr<int> = elementAddr(%barr, 0) %element_int = load(%elementPtr) %element = cast(%element_int) : %bool ... // uses %element func computeMain: call %test %cb ``` In this new IR, there is no longer a load and conversion of the entire array. See the new comment in `slang-ir-lower-buffer-element-type.cpp` for more details of how the pass works. This PR also addresses many other issues surfaced by turning on the `transformParamsToConstRef` pass on the SPIRV backend. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang-glslang/slang-glslang.cpp7
-rw-r--r--source/slang/core.meta.slang1
-rw-r--r--source/slang/slang-emit-spirv.cpp2
-rw-r--r--source/slang/slang-emit.cpp37
-rw-r--r--source/slang/slang-ir-defer-buffer-load.cpp38
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp110
-rw-r--r--source/slang/slang-ir-insts-stable-names.lua6
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-insts.lua12
-rw-r--r--source/slang/slang-ir-legalize-types.cpp10
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp1507
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.h2
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp17
-rw-r--r--source/slang/slang-ir-metal-legalize.h2
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp3
-rw-r--r--source/slang/slang-ir-specialize-address-space.cpp39
-rw-r--r--source/slang/slang-ir-specialize-address-space.h7
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp46
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.cpp142
-rw-r--r--source/slang/slang-ir-undo-param-copy.cpp14
-rw-r--r--source/slang/slang-ir-util.cpp77
-rw-r--r--source/slang/slang-ir-util.h20
-rw-r--r--source/slang/slang-ir.cpp69
-rw-r--r--source/slang/slang-ir.h10
-rw-r--r--source/slang/slang-lower-to-ir.cpp21
-rw-r--r--tests/compute/byte-address-buffer-array.slang4
-rw-r--r--tests/optimization/arrray-storage-lowering.slang42
-rw-r--r--tests/optimization/get-array-element.slang17
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex.slang7
-rw-r--r--tests/spirv/aligned-load-store.slang2
-rw-r--r--tests/spirv/buffer-pointer-matrix-layout.slang69
-rw-r--r--tests/spirv/geometry-shader-sub-func.slang2
-rw-r--r--tests/spirv/large-struct.slang2
-rw-r--r--tests/spirv/pointer-2.slang6
-rw-r--r--tests/spirv/spec-constant-operations.slang2
-rw-r--r--tests/spirv/spirv-debug-break.slang2
-rw-r--r--tools/render-test/shader-input-layout.cpp7
37 files changed, 1754 insertions, 621 deletions
diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp
index 56be8b042..1c91a97bc 100644
--- a/source/slang-glslang/slang-glslang.cpp
+++ b/source/slang-glslang/slang-glslang.cpp
@@ -241,7 +241,12 @@ extern "C"
#endif
bool glslang_disassembleSPIRV(const uint32_t* contents, int contentsSize)
{
- return glslang_disassembleSPIRVWithResult(contents, contentsSize, nullptr);
+ char* result = nullptr;
+ auto succ = glslang_disassembleSPIRVWithResult(contents, contentsSize, &result);
+ if (result)
+ fprintf(stdout, "%s\n", result);
+ delete result;
+ return succ;
}
// Apply the SPIRV-Tools optimizer to generated SPIR-V based on the desired optimization level
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 43658563e..54e693117 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1312,7 +1312,6 @@ enum Access : uint64_t
/// @param T The type of the value pointed to.
/// @remarks `T* val` is equivalent to `Ptr<T> val`.
__magic_type(PtrType)
-__intrinsic_type($(kIROp_PtrType))
struct Ptr<
T,
Access access = Access::ReadWrite,
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 7b1bd66d8..ea49cb08c 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -7038,7 +7038,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())),
builder.getIntType());
SLANG_ASSERT(as<IRPtrTypeBase>(fieldAddress->getFullType()));
- return emitOpInBoundsAccessChain(
+ return emitOpAccessChain(
parent,
fieldAddress,
fieldAddress->getFullType(),
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index e1689ccfc..f1cc6090d 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -97,6 +97,7 @@
#include "slang-ir-restructure.h"
#include "slang-ir-sccp.h"
#include "slang-ir-simplify-for-emit.h"
+#include "slang-ir-specialize-address-space.h"
#include "slang-ir-specialize-arrays.h"
#include "slang-ir-specialize-buffer-load-arg.h"
#include "slang-ir-specialize-matrix-layout.h"
@@ -1715,6 +1716,7 @@ Result linkAndOptimizeIR(
if (targetProgram->getOptionSet().getBoolOption(
CompilerOptionName::EnableExperimentalPasses))
introduceExplicitGlobalContext(irModule, target);
+ transformParamsToConstRef(irModule, codeGenContext->getSink());
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "EXPLICIT GLOBAL CONTEXT INTRODUCED");
#endif
@@ -1812,11 +1814,11 @@ Result linkAndOptimizeIR(
if (requiredLoweringPassSet.meshOutput)
legalizeMeshOutputTypes(irModule);
- BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions;
- bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer =
- isWGPUTarget(targetRequest);
- lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions);
- performForceInlining(irModule);
+
+ // Lower all bit_cast operations on complex types into leaf-level
+ // bit_cast on basic types.
+ if (requiredLoweringPassSet.bitcast)
+ lowerBitCast(targetProgram, irModule, sink);
// Rewrite functions that return arrays to return them via `out` parameter,
// since our target languages doesn't allow returning arrays.
@@ -1832,13 +1834,28 @@ Result linkAndOptimizeIR(
rcpWOfPositionInput(irModule);
}
- // Lower all bit_cast operations on complex types into leaf-level
- // bit_cast on basic types.
- if (requiredLoweringPassSet.bitcast)
- lowerBitCast(targetProgram, irModule, sink);
-
bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly();
+ BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions;
+ bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer =
+ isWGPUTarget(targetRequest);
+ lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions);
+
+ // If we are generating code for glsl or metal, perform address space propagation now.
+ // For SPIRV, we will do that during spirv legalization that happens after
+ // `linkAndOptimizeIR`.
+ if (target == CodeGenTarget::GLSL)
+ {
+ NoOpInitialAddressSpaceAssigner addrSpaceAssigner;
+ specializeAddressSpace(irModule, &addrSpaceAssigner);
+ }
+ else if (isMetalTarget(targetRequest))
+ {
+ specializeAddressSpaceForMetal(irModule);
+ }
+
+ performForceInlining(irModule);
+
if (emitSpirvDirectly)
{
performIntrinsicFunctionInlining(irModule);
diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index e71892fe5..51c6a161b 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -67,38 +67,6 @@ struct DeferBufferLoadContext
return result;
}
- static bool isImmutableLocation(IRInst* loc)
- {
- switch (loc->getOp())
- {
- case kIROp_GetStructuredBufferPtr:
- case kIROp_ImageSubscript:
- return isImmutableLocation(loc->getOperand(0));
- default:
- break;
- }
-
- auto type = loc->getDataType();
- if (!type)
- return false;
-
- switch (type->getOp())
- {
- case kIROp_HLSLStructuredBufferType:
- case kIROp_HLSLByteAddressBufferType:
- case kIROp_ConstantBufferType:
- case kIROp_ParameterBlockType:
- return true;
- default:
- break;
- }
-
- if (auto textureType = as<IRTextureType>(type))
- return textureType->getAccess() == SLANG_RESOURCE_ACCESS_READ;
-
- return false;
- }
-
static bool isImmutableBufferLoad(IRInst* inst)
{
// Note: we cannot defer loads from RWStructuredBuffer because there can be other
@@ -111,7 +79,7 @@ struct DeferBufferLoadContext
case kIROp_Load:
{
auto rootAddr = getRootAddr(inst->getOperand(0));
- return isImmutableLocation(rootAddr);
+ return isPointerToImmutableLocation(rootAddr);
}
default:
return false;
@@ -132,14 +100,14 @@ struct DeferBufferLoadContext
align = load->findAttr<IRAlignedAttr>();
if (!as<IRModuleInst>(ptr->getParent()))
{
- builder.setInsertAfter(ptr);
+ setInsertAfterOrdinaryInst(&builder, ptr);
IRType* valueType = tryGetPointedToType(&builder, ptr->getFullType());
result = builder.emitLoad(valueType, ptr, align);
mapPtrToValue[ptr] = result;
}
else
{
- builder.setInsertBefore(loadInst);
+ setInsertBeforeOrdinaryInst(&builder, loadInst);
IRType* valueType = tryGetPointedToType(&builder, ptr->getFullType());
result = builder.emitLoad(valueType, ptr, align);
// Since we are inserting the load in a local scope, we can't register
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index a2f56cf7d..a79ca2379 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -1079,6 +1079,12 @@ IRInst* getOrCreateBuiltinParamForHullShader(
if (sysAttr->getName().caseInsensitiveEquals(builtinSemantic))
{
outputControlPointIdParam = param;
+ if (as<IRPtrTypeBase>(outputControlPointIdParam->getDataType()))
+ {
+ IRBuilder builder(param);
+ setInsertAfterOrdinaryInst(&builder, param);
+ outputControlPointIdParam = builder.emitLoad(param);
+ }
break;
}
}
@@ -2348,11 +2354,11 @@ ScalarizedVal getSubscriptVal(
auto inputAdapter = val.impl.as<ScalarizedTypeAdapterValImpl>();
RefPtr<ScalarizedTypeAdapterValImpl> resultAdapter = new ScalarizedTypeAdapterValImpl();
- resultAdapter->pretendType = inputAdapter->pretendType;
- resultAdapter->actualType = inputAdapter->actualType;
+ resultAdapter->pretendType = elementType;
+ resultAdapter->actualType = getElementType(*builder, inputAdapter->actualType);
resultAdapter->val =
- getSubscriptVal(builder, inputAdapter->actualType, inputAdapter->val, indexVal);
+ getSubscriptVal(builder, resultAdapter->actualType, inputAdapter->val, indexVal);
return ScalarizedVal::typeAdapter(resultAdapter);
}
@@ -3127,7 +3133,7 @@ void tryReplaceUsesOfStageInput(
{
auto user = use->getUser();
IRBuilder builder(user);
- builder.setInsertBefore(user);
+ setInsertBeforeOrdinaryInst(&builder, user);
builder.replaceOperand(use, val.irValue);
});
}
@@ -3155,7 +3161,7 @@ void tryReplaceUsesOfStageInput(
return;
}
IRBuilder builder(user);
- builder.setInsertBefore(user);
+ setInsertBeforeOrdinaryInst(&builder, user);
if (needMaterialize)
{
auto materializedVal = materializeValue(&builder, val);
@@ -3176,22 +3182,50 @@ void tryReplaceUsesOfStageInput(
{
auto user = use->getUser();
IRBuilder builder(user);
- builder.setInsertBefore(user);
+ setInsertBeforeOrdinaryInst(&builder, user);
auto typeAdapter = as<ScalarizedTypeAdapterValImpl>(val.impl);
- auto materializedInner = materializeValue(&builder, typeAdapter->val);
- auto adapted = adaptType(
- &builder,
- materializedInner,
- typeAdapter->pretendType,
- typeAdapter->actualType);
- if (user->getOp() == kIROp_Load)
- {
- user->replaceUsesWith(adapted.irValue);
- user->removeAndDeallocate();
- }
- else
+ switch (user->getOp())
{
- use->set(adapted.irValue);
+ case kIROp_Load:
+ {
+ auto materialized = materializeValue(&builder, val);
+ user->replaceUsesWith(materialized);
+ user->removeAndDeallocate();
+ }
+ break;
+ case kIROp_GetElementPtr:
+ {
+ auto targetType = typeAdapter->pretendType;
+ auto elementType = getElementType(builder, targetType);
+ SLANG_ASSERT(elementType);
+ auto subscriptVal = getSubscriptVal(
+ &builder,
+ (IRType*)elementType,
+ val,
+ user->getOperand(1));
+ tryReplaceUsesOfStageInput(context, subscriptVal, user);
+ }
+ break;
+ case kIROp_FieldAddress:
+ {
+ auto targetType = as<IRStructType>(typeAdapter->pretendType);
+ SLANG_ASSERT(targetType);
+ auto subscriptVal = extractField(
+ &builder,
+ val,
+ kMaxUInt,
+ (IRStructKey*)user->getOperand(1));
+ tryReplaceUsesOfStageInput(context, subscriptVal, user);
+ }
+ break;
+ default:
+ {
+ auto materialized = materializeValue(&builder, val);
+ auto tmpVar = builder.emitVar(materialized->getDataType());
+ builder.emitStore(tmpVar, materialized);
+ use->set(tmpVar);
+ }
+ break;
}
});
}
@@ -3205,13 +3239,13 @@ void tryReplaceUsesOfStageInput(
auto arrayIndexImpl = as<ScalarizedArrayIndexValImpl>(val.impl);
auto user = use->getUser();
IRBuilder builder(user);
- builder.setInsertBefore(user);
+ setInsertBeforeOrdinaryInst(&builder, user);
auto subscriptVal = getSubscriptVal(
&builder,
arrayIndexImpl->elementType,
arrayIndexImpl->arrayVal,
arrayIndexImpl->index);
- builder.setInsertBefore(user);
+ setInsertBeforeOrdinaryInst(&builder, user);
auto materializedInner = materializeValue(&builder, subscriptVal);
if (user->getOp() == kIROp_Load)
{
@@ -3220,7 +3254,9 @@ void tryReplaceUsesOfStageInput(
}
else
{
- use->set(materializedInner);
+ auto tmpVar = builder.emitVar(materializedInner->getDataType());
+ builder.emitStore(tmpVar, materializedInner);
+ use->set(tmpVar);
}
});
break;
@@ -3233,6 +3269,9 @@ void tryReplaceUsesOfStageInput(
[&](IRUse* use)
{
auto user = use->getUser();
+ IRBuilder builder(user);
+ setInsertBeforeOrdinaryInst(&builder, user);
+
switch (user->getOp())
{
case kIROp_FieldExtract:
@@ -3270,10 +3309,20 @@ void tryReplaceUsesOfStageInput(
}
}
break;
+ case kIROp_GetElementPtr:
+ {
+ auto arrayType = as<IRArrayTypeBase>(tupleVal->type);
+ SLANG_ASSERT(arrayType);
+ auto subscriptVal = getSubscriptVal(
+ &builder,
+ (IRType*)arrayType->getElementType(),
+ val,
+ user->getOperand(1));
+ tryReplaceUsesOfStageInput(context, subscriptVal, user);
+ }
+ break;
case kIROp_Load:
{
- IRBuilder builder(user);
- builder.setInsertBefore(user);
auto materializedVal = materializeTupleValue(&builder, val);
user->replaceUsesWith(materializedVal);
user->removeAndDeallocate();
@@ -3449,7 +3498,7 @@ void legalizeEntryPointParameterForGLSL(
// Okay, we have a declaration, and we want to modify it!
- builder->setInsertBefore(ii);
+ setInsertBeforeOrdinaryInst(builder, ii);
assign(builder, globalOutputVal, ScalarizedVal::value(ii->getOperand(2)));
}
@@ -3768,12 +3817,13 @@ void legalizeEntryPointParameterForGLSL(
blockToMaterialized.tryGetValue(callingBlock, materialized);
if (!found)
{
- replaceBuilder.setInsertBefore(callingBlock->getFirstInst());
+ replaceBuilder.setInsertBefore(
+ callingBlock->getFirstOrdinaryInst());
materialized = materializeValue(&replaceBuilder, globalValue);
blockToMaterialized.set(callingBlock, materialized);
}
- replaceBuilder.setInsertBefore(user);
+ setInsertBeforeOrdinaryInst(builder, user);
auto field =
replaceBuilder.emitFieldExtract(globalVarType, materialized, key);
replaceBuilder.replaceOperand(operandUse, field);
@@ -3888,7 +3938,7 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module)
{
rayPayloadCounter++;
}
- builder.setInsertBefore(inst);
+ setInsertBeforeOrdinaryInst(&builder, inst);
location = builder.getIntValue(builder.getIntType(), rayPayloadCounter);
decor->setOperand(0, location);
rayPayloadCounter++;
@@ -3902,7 +3952,7 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module)
{
callablePayloadCounter++;
}
- builder.setInsertBefore(inst);
+ setInsertBeforeOrdinaryInst(&builder, inst);
location = builder.getIntValue(builder.getIntType(), callablePayloadCounter);
decor->setOperand(0, location);
callablePayloadCounter++;
@@ -3915,7 +3965,7 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module)
{
hitObjectAttributeCounter++;
}
- builder.setInsertBefore(inst);
+ setInsertBeforeOrdinaryInst(&builder, inst);
location = builder.getIntValue(builder.getIntType(), hitObjectAttributeCounter);
decor->setOperand(0, location);
hitObjectAttributeCounter++;
diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua
index b2c216bb4..25d54eb04 100644
--- a/source/slang/slang-ir-insts-stable-names.lua
+++ b/source/slang/slang-ir-insts-stable-names.lua
@@ -670,6 +670,10 @@ return {
["SPIRVAsmOperand.__imageType"] = 666,
["SPIRVAsmOperand.__sampledImageType"] = 667,
["Type.CLayout"] = 668,
- ["CastUInt64ToDescriptorHandle"] = 669,
+ ["CastUInt64ToDescriptorHandle"] = 669,
["CastDescriptorHandleToUInt64"] = 670,
+ ["CastStorageToLogicalBase.CastStorageToLogical"] = 671,
+ ["CastStorageToLogicalBase.CastStorageToLogicalDeref"] = 672,
+ ["Decoration.DisableCopyEliminationDecoration"] = 673,
+ ["Decoration.TempCallArgImmutableVar"] = 674,
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index e3014119c..c64f65ccb 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3256,6 +3256,14 @@ struct IRCastFloatToInt : IRInst
};
FIDDLE()
+struct IRCastStorageToLogicalBase : IRInst
+{
+ FIDDLE(baseInst())
+ IRInst* getVal() { return getOperand(0); }
+ IRInst* getBufferType() { return getOperand(1); }
+};
+
+FIDDLE()
struct IRDebugSource : IRInst
{
FIDDLE(leafInst())
@@ -4573,6 +4581,12 @@ public:
IRInst* emitCastPtrToInt(IRInst* val);
IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val);
+ IRInst* emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType);
+ IRCastStorageToLogicalDeref* emitCastStorageToLogicalDeref(
+ IRType* type,
+ IRInst* val,
+ IRInst* bufferType);
+
IRGlobalConstant* emitGlobalConstant(IRType* type);
IRGlobalConstant* emitGlobalConstant(IRType* type, IRInst* val);
diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua
index 0ad02b87c..5f54707a1 100644
--- a/source/slang/slang-ir-insts.lua
+++ b/source/slang/slang-ir-insts.lua
@@ -1473,6 +1473,7 @@ local insts = {
struct_name = "RequireFullQuadsDecoration",
},
},
+ { TempCallArgImmutableVar = { struct_name = "TempCallArgImmutableVarDecoration" } },
{ TempCallArgVar = { struct_name = "TempCallArgVarDecoration" } },
{
nonCopyable = {
@@ -1480,6 +1481,7 @@ local insts = {
struct_name = "NonCopyableTypeDecoration",
},
},
+ { DisableCopyEliminationDecoration = {} },
{
DynamicUniform = {
-- Marks a value to be dynamically uniform.
@@ -1891,6 +1893,16 @@ local insts = {
{ EnumCast = { min_operands = 1 } },
{ CastUInt2ToDescriptorHandle = { min_operands = 1 } },
{ CastDescriptorHandleToUInt2 = { min_operands = 1 } },
+ -- Represents a pseudo cast to convert between a logical type (user declared) and a storage type
+ -- (valid in buffer locations). The operand can either be a value or an address.
+ {
+ CastStorageToLogicalBase =
+ {
+ min_operands = 2, struct_name = "CastStorageToLogicalBase",
+ { CastStorageToLogical = { min_operands = 2, struct_name = "CastStorageToLogical" } },
+ { CastStorageToLogicalDeref = { min_operands = 2, struct_name = "CastStorageToLogicalDeref" } },
+ }
+ },
{ CastUInt64ToDescriptorHandle = { min_operands = 1 } },
{ CastDescriptorHandleToUInt64 = { min_operands = 1 } },
-- Represents a no-op cast to convert a resource pointer to a resource on targets where the resource handles are
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 085c3d933..27abdeaf0 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -2344,6 +2344,10 @@ static LegalVal legalizeCoopMatMapElementIFunc(
static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst)
{
+ LegalVal legalVal;
+ if (context->mapValToLegalVal.tryGetValue(inst, legalVal))
+ return legalVal;
+
// Any additional instructions we need to emit
// in the process of legalizing `inst` should
// by default be insertied right before `inst`.
@@ -2463,7 +2467,7 @@ static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst)
auto builder = context->builder;
builder->setInsertBefore(inst);
- LegalVal legalVal = legalizeInst(context, inst, legalType, legalArgs.getArrayView().arrayView);
+ legalVal = legalizeInst(context, inst, legalType, legalArgs.getArrayView().arrayView);
if (legalVal.flavor == LegalVal::Flavor::simple)
{
@@ -2789,7 +2793,9 @@ private:
static LegalVal legalizeFunc(IRTypeLegalizationContext* context, IRFunc* irFunc)
{
LegalFuncBuilder builder(context);
- return builder.build(irFunc);
+ auto legalVal = builder.build(irFunc);
+ registerLegalizedValue(context, irFunc, legalVal);
+ return legalVal;
}
static void cloneDecorationToVar(IRInst* srcInst, IRInst* varInst)
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 056ee6244..128502bd8 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -6,6 +6,211 @@
#include "slang-ir-util.h"
#include "slang-ir.h"
+/// This file implements an important IR transformation pass in the Slang compiler
+/// that rewrites buffer element types into valid storage types, a.k.a physical types
+/// in SPIRV terminology.
+///
+/// Many of our targets have special restrictions on what is allowed to be used as a
+/// buffer element. Examples are:
+/// - In HLSL and SPIRV, if you have ConstantBuffer<T>, T must be a struct.
+/// - In SPIRV, `bool` is considered a logical type, meaning it cannot appear inside
+/// buffers. bool vectors and matrices need to be lowered into arrays.
+/// - In SPIRV, if `T` is used to declare a buffer, then every member in `T` must have
+/// explicit offset. But if it is used to declare a local variable, then it cannot
+/// have explicit member offset. This means that we cannot use the same `Foo` struct
+/// inside a `StructuredBuffer<Foo>` and also use it to declare a local variable.
+///
+/// We use the terms "physical", "storage", or "lowered" types to refer to types that
+/// are legal to use as buffer elements. In contrast, the terms "original" or "logical"
+/// refer to types as declared by the user, in their original form.
+/// For example, `bool4` is a "logical" type, and its lowered type is `int4`.
+///
+///
+/// # Algorithm Overview
+/// ----------------------
+///
+/// This pass performs the transformation to create one "storage" type for each type that
+/// is used in each kind of buffer. For example, if the user defined `Foo`, and used it in
+/// `ConstantBuffer<Foo>` and `StructuredBuffer<Foo>` and is targeting SPIRV, this pass will
+/// create `Foo_std140` and `Foo_std430` types, and update the buffer to be
+/// `ConstantBuffer<Foo_std140>` and `StructuredBuffer<Foo_std430>`.
+///
+/// The pass will rewrite all the code that uses these buffers, and insert translations between
+/// Foo_std140/Foo_std430 and Foo to keep types consistent.
+///
+/// For example, given:
+/// ```
+/// struct Foo {
+/// bool4x4 v;
+/// }
+/// ConstantBuffer<Foo> cb;
+/// bool test(Foo f) {
+/// return f.v[0][1];
+/// }
+/// void main() { test(cb); }
+/// ```
+///
+/// This pass will rewrite it as:
+/// ```
+/// struct Foo {
+/// bool4x4 v;
+/// }
+/// struct Foo_std140 {
+/// Matrix_bool4x4_std140 v;
+/// };
+/// struct Matrix_bool4x4_std140 {
+/// int4 values[4];
+/// };
+/// ConstantBuffer<Foo_std140> cb;
+/// bool test_1(Foo_std140 f) {
+/// return f.v.values[0][1];
+/// }
+/// void main() { test_1(cb); }
+/// ```
+///
+/// Note that the one important optimization here is we will defer the translation from
+/// storage type to logical type at latest possible time. In the example above, we could
+/// have loaded `cb` and then immediately translate it into `Foo` and call `test` with
+/// the translated value. However that can lead to code that create unnecessary copies
+/// that can't always be removed by the downstream compiler, particularly if there are
+/// arrays whose element type needs non-trivial translation.
+///
+/// To avoid the performance issue, we will defer this translation until a logical value
+/// is actually needed. This is done by pushing the translation to the use sites, and
+/// across function call boundaries, specializing any functions being called along the
+/// chain. In this case, since we are calling `test()` from `main()` with `Foo_std140`, instead
+/// of converting the `Foo_std140` to `Foo` before the call, we create a specialization
+/// of `test` that accepts `Foo_std140` instead.
+///
+/// To enable this interprocedural transformation, the pass is organized as the following steps:
+/// 1. Create lowered / storage types for all buffer element types, and update
+/// global buffer declarations to use storage types. This is implemented in `processModule()`
+/// 2. Insert a `CastStorageToLogical(loweredBuffer)` inst, and replace all uses of
+/// `loweredBuffer` with the cast inst. This is implemented in `processModule()`
+/// 3. Push the `CastStorageToLogical` insts to as late as possible, which means if we see
+/// `FieldAddress(CastStorageToLogical(storageAddr), memberKey)`, we should translate
+/// it into `CastStorageToLogical(FieldAddress(storageAddr, memberKey))`.
+/// If we see a `CastStorageToLogical` inst being used as argument to call a function `f`,
+/// specialize `f` to take a pointer to the storage type instead, and insert a
+/// `CastStorageToLogical(param)` to convert the param type to logical type at the
+/// beginning of the specialized function. (implemented in `deferStorageToLogicalCasts()`)
+///
+/// Repeat step 2 and 3 until no more changes can be made, then proceed to step 4.
+///
+/// 4. Materialize all remaining `CastStorageToLogical(addr)` by replacing all `load` of such
+/// cast insts with `call unpackStorage(addr)`, where `unpackStorage` is a function we
+/// synthesize that reads from an address of a storage type and returns a logical type;
+/// and replacing all `store(CastStorageToLogical(addr), value)` with `packStorage(addr, value)`,
+/// where `packStorage` is a function we synthesize that writes a logical value into a storage
+/// addr. This is implemented in `materializeStorageToLogicalCasts()`.
+///
+/// That's the main idea of the pass.
+///
+/// # Propagating through SSA values
+///
+/// Note that `kIROp_CastStorageToLogical` is a pseudo instruction introduced in this pass that
+/// has the semantics of "converting a pointer to a storage value into a pointer to a logical
+/// value". A dual of this inst is `kIROp_CastStorageToLogicalDeref`, which has an additional
+/// builtin "load" semantic. That is, given `Ptr<StorageType> addr`, `CastStorageToLogical(addr)`
+/// will have type `Ptr<LogicalType>`, and `CastStorageToLogicalDeref(addr)` will have type
+/// `LogicalType`. In other words, `CastStorageToLogicalDeref(addr)` is equivalent to
+/// `load(CastStorageToLogical(addr))`.
+///
+/// The `CastStorageToLogicalDeref` pseudo inst is needed to defer the cast past `load`s.
+/// Consider the following example:
+/// ```
+/// ptr : StorageType* = ...
+/// lptr : LogicalType* = CastStorageToLogical(ptr);
+/// l = load(lptr)
+/// m = fieldExtract(l, member)
+/// call f, m
+/// ```
+/// In this case, only l.member is used, so we should avoid translating other unrelated members
+/// from storage type to logical type. To achieve this we must be able to push the
+/// `CastStorageToLogical` operation beyond the `load`. The steps to achieve this are:
+/// 1. we process `lptr` inst by inspecting its users. We find that a `load` (l) uses it.
+/// 2. replace the `load` with `CastStorageToLogicalDeref(ptr)`, the IR becomes:
+/// ```
+/// ptr : StorageType* = ...
+/// l_1 = CastStorageToLogicalDeref(ptr);
+/// m = fieldExtract(l_1, member);
+/// call f, m
+/// ```
+/// 3. push the new `l_1` inst to worklist, and when it gets processed, we continue to inspect
+/// its users, and find that it is being used by `fieldExtract`. We will rewrite the
+/// `fieldExtract` into `CastStorageToLogicalDeref(fieldAddr(ptr, member))`, and the IR become:
+/// ```
+/// ptr : StorageType* = ...
+/// m_ptr = FieldAddr(ptr, member)
+/// m = CastStorageToLogicalDeref(m_ptr);
+/// call f, m
+/// ```
+/// 4. Since there are no more uses of `m` that can be translated, stop. Note that it is possible
+/// to continue specializing `f` and replace its first parameter's type to storage type. However
+/// this implementation currently does not specialize functions whose parameter type is not a
+/// pointer/reference type. When we target SPIRV, we will already be running the
+/// `transformParamsToConstRef` pass that would have converted `f` to take in `ConstRef<T>`.
+/// In this case, the initial IR would be in the form of
+/// ```
+/// ptr : StorageType* = ...
+/// lptr : LogicalType* = CastStorageToLogical(ptr);
+/// l = load(lptr)
+/// m = fieldExtract(l, member)
+/// var tmpVar : MemberLogicalType [[ImmutableTempVar]]
+/// store tmpVar, m
+/// call f, tmpVar
+/// ```
+/// To allow us to remove the `tmpVar` store introduced during `transformParamsToConstRef`,
+/// this pass also handles the propagation through temp var stores. After pushing the cast
+/// through `m`, we will get IR to this form:
+/// ```
+/// ptr : StorageType* = ...
+/// m_ptr = FieldAddr(ptr, member)
+/// m = CastStorageToLogicalDeref(m_ptr);
+/// var tmpVar : MemberLogicalType [[ImmutableTempVar]]
+/// store tmpVar, m
+/// call f, tmpVar
+/// ```
+/// This time, we will see that `m` is being used by a `store` into a `[[ImmutableTempVar]]` var,
+/// and we can safely replace all uses of `tmpVar` with `CastStorageToLogical(m_ptr)`, so the IR becomes:
+/// ```
+/// ptr : StorageType* = ...
+/// m_ptr = FieldAddr(ptr, member)
+/// m = CastStorageToLogical(m_ptr);
+/// call f, m
+/// ```
+/// Now, we are in the case where a `CastStorageToLogical` is used as argument in a `call`.
+/// This will trigger our function specialization rule to create `f_1` that accepts a
+/// `StorageMember*`, and we will rewrite the IR again to:
+/// ```
+/// ptr : StorageType* = ...
+/// m_ptr = FieldAddr(ptr, member)
+/// call f_1, m_ptr
+/// ```
+///
+/// # Trailing Pointer Rewrite
+///
+/// Another transformation done in this pass is rewriting structs with unsized trailing
+/// arrays. Since an unsized type isn't a physical type and cannot be used as a pointee type,
+/// we would have a problem translating the following code to SPIRV:
+/// ```
+/// struct Foo { int count; int[] values; }
+/// uniform Foo* b;
+/// ```
+///
+/// When we create a storage type for `Foo`, we will define it as:
+/// ```
+/// struct Foo_std430 { int count; }
+/// ```
+/// Where we removed the trailing array.
+/// This makes `Foo_std430` an ordinary sized type that can be used freely as pointee type
+/// in SPIRV.
+///
+/// However this does mean that we also need to translate things like `ptr->values[2]`
+/// into `((int*)(ptr+1))[2]`. Which we also handle during step 2 of the algorithm.
+/// (`maybeTranslateTrailingPointerGetElementAddress`)
+///
+
namespace Slang
{
@@ -85,7 +290,7 @@ struct LoweredElementTypeContext
else
{
auto val = builder.emitIntrinsicInst(
- tryGetPointedToType(&builder, dest->getDataType()),
+ tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
op,
1,
&operand);
@@ -145,6 +350,22 @@ struct LoweredElementTypeContext
TargetProgram* target;
BufferElementTypeLoweringOptions options;
+ struct SpecializationKey
+ {
+ IRFunc* callee;
+ IRFuncType* specializedFuncType;
+ bool operator==(const SpecializationKey& other) const
+ {
+ return (callee == other.callee && specializedFuncType == other.specializedFuncType);
+ }
+ HashCode64 getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(callee), Slang::getHashCode(specializedFuncType));
+ }
+ };
+ // Specialized functions that take storage-typed pointers instead of logical-typed pointers.
+ Dictionary<SpecializationKey, IRFunc*> specializedFuncs;
+
LoweredElementTypeContext(
TargetProgram* target,
BufferElementTypeLoweringOptions inOptions,
@@ -881,17 +1102,25 @@ struct LoweredElementTypeContext
IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType)
{
- if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) ||
+ IRBuilder builder(newElementType);
+ builder.setInsertAfter(newElementType);
+ if (auto ptrType = as<IRPtrTypeBase>(originalPtrLikeType))
+ {
+ return builder.getPtrType(newElementType, ptrType);
+ }
+
+ if (as<IRPointerLikeType>(originalPtrLikeType) ||
as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType) ||
as<IRGLSLShaderStorageBufferType>(originalPtrLikeType))
{
- IRBuilder builder(newElementType);
- builder.setInsertAfter(newElementType);
ShortList<IRInst*> operands;
- for (UInt i = 0; i < originalPtrLikeType->getOperandCount(); i++)
+ operands.add(newElementType);
+ for (UInt i = 1; i < originalPtrLikeType->getOperandCount(); i++)
+ {
operands.add(originalPtrLikeType->getOperand(i));
- operands[0] = newElementType;
- return builder.getType(
+ }
+ return (IRType*)builder.emitIntrinsicInst(
+ builder.getTypeKind(),
originalPtrLikeType->getOp(),
(UInt)operands.getCount(),
operands.getArrayView().getBuffer());
@@ -914,26 +1143,544 @@ struct LoweredElementTypeContext
TypeLoweringConfig config;
};
- IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst)
+ IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst, IRInst* baseAddr)
{
switch (loadStoreInst->getOp())
{
case kIROp_Load:
case kIROp_Store:
- return loadStoreInst->getOperand(0);
+ return baseAddr;
case kIROp_StructuredBufferLoad:
case kIROp_StructuredBufferLoadStatus:
case kIROp_RWStructuredBufferLoad:
case kIROp_RWStructuredBufferLoadStatus:
case kIROp_RWStructuredBufferStore:
return builder.emitRWStructuredBufferGetElementPtr(
- loadStoreInst->getOperand(0),
+ baseAddr,
loadStoreInst->getOperand(1));
default:
return nullptr;
}
}
+ bool maybeTranslateTrailingPointerGetElementAddress(
+ IRBuilder& builder,
+ IRFieldAddress* fieldAddr,
+ IRCastStorageToLogicalBase* castInst,
+ TypeLoweringConfig& config,
+ List<IRCastStorageToLogicalBase*>& castInstWorkList)
+ {
+ // If we are accessing an unsized array element from a pointer, we need to
+ // compute
+ // the trailing ptr that points to the first element of the array.
+ // And then replace all getElementPtr(arrayPtr, index) with
+ // getOffsetPtr(trailingPtr, index).
+
+ auto ptrType = as<IRPtrTypeBase>(fieldAddr->getDataType());
+ if (!ptrType)
+ return false;
+ if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
+ return false;
+ if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType()))
+ {
+ builder.setInsertBefore(fieldAddr);
+ auto newArrayPtrVal = fieldAddr->getBase();
+ auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), config);
+
+ IRSizeAndAlignment arrayElementSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ loweredInnerType.loweredType,
+ &arrayElementSizeAlignment);
+ IRSizeAndAlignment baseSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()),
+ &baseSizeAlignment);
+
+ // Convert pointer to uint64 and adjust offset.
+ IRIntegerValue offset = baseSizeAlignment.size;
+ offset = align(offset, arrayElementSizeAlignment.alignment);
+ if (offset != 0)
+ {
+ auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
+ newArrayPtrVal = builder.emitAdd(
+ rawPtr->getFullType(),
+ rawPtr,
+ builder.getIntValue(builder.getUInt64Type(), offset));
+ }
+ newArrayPtrVal = builder.emitBitCast(
+ builder.getPtrType(loweredInnerType.loweredType, ptrType),
+ newArrayPtrVal);
+ traverseUses(
+ fieldAddr,
+ [&](IRUse* fieldAddrUse)
+ {
+ auto fieldAddrUser = fieldAddrUse->getUser();
+ if (fieldAddrUser->getOp() == kIROp_GetElementPtr)
+ {
+ builder.setInsertBefore(fieldAddrUser);
+ auto newElementPtr =
+ builder.emitGetOffsetPtr(newArrayPtrVal, fieldAddrUser->getOperand(1));
+ auto castedGEP = builder.emitCastStorageToLogical(
+ fieldAddrUser->getFullType(),
+ newElementPtr,
+ castInst->getBufferType());
+ fieldAddrUser->replaceUsesWith(castedGEP);
+ fieldAddrUser->removeAndDeallocate();
+ if (auto castStorage = as<IRCastStorageToLogicalBase>(castedGEP))
+ castInstWorkList.add(castStorage);
+ }
+ else if (fieldAddrUser->getOp() == kIROp_GetOffsetPtr)
+ {
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unknown use of pointer to unsized array.");
+ }
+ });
+ SLANG_ASSERT(!fieldAddr->hasUses());
+ fieldAddr->removeAndDeallocate();
+ return true;
+ }
+ return false;
+ }
+
+
+ // Helper function to discover all `call`s in `func` that have at least one argument
+ // that is a `CastStorageToLogical`.
+ void discoverCallsToProcess(List<IRCall*>& callWorkList, IRFunc* func)
+ {
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ auto call = as<IRCall>(inst);
+ if (!call)
+ continue;
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ if (arg->getOp() == kIROp_CastStorageToLogical)
+ {
+ callWorkList.add(call);
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ void deferStorageToLogicalCasts(
+ IRModule* module,
+ List<IRCastStorageToLogicalBase*> castInstWorkList)
+ {
+ IRBuilder builder(module);
+
+ while (castInstWorkList.getCount())
+ {
+ // We process call instructions after other instructions, so we
+ // can be sure that all castStorageToLogical insts have already
+ // been pushed to the call argument lists before we process it.
+ HashSet<IRCall*> callWorkListSet;
+ // Defer the storage-to-logical cast operation to latest possible time to avoid
+ // unnecessary packing/unpacking.
+ for (Index i = 0; i < castInstWorkList.getCount(); i++)
+ {
+ auto castInst = castInstWorkList[i];
+ auto ptrVal = castInst->getOperand(0);
+ auto config =
+ getTypeLoweringConfigForBuffer(target, (IRType*)castInst->getBufferType());
+ traverseUses(
+ castInst,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ switch (user->getOp())
+ {
+ case kIROp_FieldAddress:
+ if (!isUseBaseAddrOperand(use, user))
+ break;
+ // If our logical struct type ends with an unsized array field, the
+ // storage struct type won't have this field defined.
+ // Therefore, all fieldAddress(obj, lastField) inst retrieving the last
+ // field of such struct should be translated into
+ // `(ArrayElementType*)((StorageStruct*)(obj)+1) + idx`.
+ // That is, we should first compute the trailing pointer of the
+ // struct, and replace all getElementPtr(fieldAddr, idx) with
+ // getOffsetPtr(trailingPtr, idx).
+ if (maybeTranslateTrailingPointerGetElementAddress(
+ builder,
+ (IRFieldAddress*)user,
+ castInst,
+ config,
+ castInstWorkList))
+ return;
+ [[fallthrough]];
+ case kIROp_GetElementPtr:
+ case kIROp_GetOffsetPtr:
+ case kIROp_RWStructuredBufferGetElementPtr:
+ {
+ // gep(castStorageToLogical(x)) ==> castStorageToLogical(gep(x))
+ if (!isUseBaseAddrOperand(use, user))
+ break;
+ auto logicalBaseType = castInst->getDataType();
+ auto logicalType = user->getDataType();
+ IRInst* storageBaseAddr = ptrVal;
+ auto originalBaseValueType =
+ tryGetPointedToOrBufferElementType(&builder, logicalBaseType);
+ if (user->getOp() == kIROp_GetElementPtr)
+ {
+ // If original type is an array, the lowered type will be a
+ // struct. In that case, all existing address insts should be
+ // appended with a field extract.
+ if (as<IRArrayType>(originalBaseValueType))
+ {
+ auto arrayLowerInfo =
+ getLoweredTypeInfo(originalBaseValueType, config);
+ if (arrayLowerInfo.loweredInnerArrayType)
+ {
+ builder.setInsertBefore(user);
+ List<IRInst*> args;
+ for (UInt i = 0; i < user->getOperandCount(); i++)
+ args.add(user->getOperand(i));
+ storageBaseAddr = builder.emitFieldAddress(
+ builder.getPtrType(
+ arrayLowerInfo.loweredInnerArrayType),
+ ptrVal,
+ arrayLowerInfo.loweredInnerStructKey);
+ }
+ }
+ if (as<IRMatrixType>(originalBaseValueType))
+ {
+ // We are trying to get a pointer to a lowered matrix
+ // element. We process this inst at a later phase.
+ SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr);
+ lowerMatrixAddresses(
+ module,
+ MatrixAddrWorkItem{user, config});
+ break;
+ }
+ }
+
+
+ builder.setInsertBefore(user);
+ IRInst* storageGEP = nullptr;
+ switch (user->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ // For standard gep instructions, use the
+ // IR builder to auto-deduce result type
+ // of the new GEP inst.
+ ShortList<IRInst*> newArgs;
+ for (UInt i = 1; i < user->getOperandCount(); i++)
+ newArgs.add(user->getOperand(i));
+ storageGEP = builder.emitElementAddress(
+ storageBaseAddr,
+ newArgs.getArrayView().arrayView);
+ break;
+ }
+ default:
+ {
+ // For non-standard gep instructions, e.g.
+ // RWStructuredBufferGetElementPtr,
+ // manually create the inst here.
+ ShortList<IRInst*> newArgs;
+ newArgs.add(storageBaseAddr);
+ for (UInt i = 1; i < user->getOperandCount(); i++)
+ newArgs.add(user->getOperand(i));
+ auto logicalValueType = tryGetPointedToOrBufferElementType(
+ &builder,
+ logicalType);
+ auto storageTypeInfo =
+ getLoweredTypeInfo(logicalValueType, config);
+ storageGEP = builder.emitIntrinsicInst(
+ builder.getPtrType(storageTypeInfo.loweredType),
+ user->getOp(),
+ newArgs.getCount(),
+ newArgs.getArrayView().getBuffer());
+ break;
+ }
+ }
+ auto castOfGEP = builder.emitCastStorageToLogical(
+ logicalType,
+ storageGEP,
+ castInst->getBufferType());
+ user->replaceUsesWith(castOfGEP);
+ user->removeAndDeallocate();
+ if (auto castStorage = as<IRCastStorageToLogical>(castOfGEP))
+ castInstWorkList.add(castStorage);
+ break;
+ }
+ case kIROp_Call:
+ {
+ // call(f, castStorageToLogical(x)) ==> call(f', x)
+ //
+ // If we see a call that takes a logical typed pointer, we will
+ // specialize the callee to take a storage typed pointer instead,
+ // and push the cast to inside the callee.
+ // We will process calls after other gep insts, so for now just add
+ // it into a separate worklist.
+ if (castInst->getOp() == kIROp_CastStorageToLogical)
+ {
+ callWorkListSet.add((IRCall*)user);
+ }
+ break;
+ }
+ case kIROp_Load:
+ case kIROp_StructuredBufferLoad:
+ case kIROp_RWStructuredBufferLoad:
+ case kIROp_StructuredBufferLoadStatus:
+ case kIROp_RWStructuredBufferLoadStatus:
+ case kIROp_StructuredBufferConsume:
+ {
+ // If we see a load(CastStorageToLogical(storageAddr)),
+ // then based on what `storageAddr` is, we will push down
+ // the cast differently.
+ // - If `storageAddr` is already a tempVar that we introduced to
+ // hold the value of a buffer resource load, we can simply
+ // convert this into `CastStorageToLogicalDeref(storageAddr)`.
+ // - Otherwise, if `storageAddr` is a buffer location, we will
+ // create a temp var to hold the result of the memory load,
+ // Then we create a `CastStorageToLogicalDeref(tempVar)`
+ // structure and use it to replace `user`.
+ // Note that it is important to introduce a temp var and preserve
+ // the buffer load operation, so we are not changing the memory
+ // semantics of the original program.
+ if (!isUseBaseAddrOperand(use, user))
+ break;
+ // If loaded value is itself a pointer or buffer,
+ // stop pushing the cast along the resulting address.
+ // we will handle loads from the pointer separately.
+ if (as<IRPointerLikeType>(user->getDataType()) ||
+ as<IRPtrTypeBase>(user->getDataType()) ||
+ as<IRHLSLStructuredBufferTypeBase>(user->getDataType()))
+ break;
+ // Don't push the cast beyond the load if we are already
+ // a simple type.
+ if (!isCompositeType(user->getDataType()))
+ break;
+ builder.setInsertBefore(user);
+ IRCloneEnv cloneEnv;
+ auto newLoad = cloneInst(&cloneEnv, &builder, user);
+ newLoad->setOperand(0, ptrVal);
+ auto elementStorageType = tryGetPointedToOrBufferElementType(
+ &builder,
+ ptrVal->getDataType());
+ newLoad->setFullType(elementStorageType);
+ IRInst* tempVar = nullptr;
+ if (as<IRLoad>(user))
+ {
+ auto rootAddr = getRootAddr(ptrVal);
+ if (rootAddr->findDecorationImpl(
+ kIROp_TempCallArgImmutableVarDecoration))
+ tempVar = ptrVal;
+ }
+ if (!tempVar)
+ {
+ tempVar = builder.emitVar(elementStorageType);
+ builder.addDecoration(
+ tempVar,
+ kIROp_TempCallArgImmutableVarDecoration);
+ builder.emitStore(tempVar, newLoad);
+ }
+ auto newCast = builder.emitCastStorageToLogicalDeref(
+ user->getFullType(),
+ tempVar,
+ castInst->getBufferType());
+ user->replaceUsesWith(newCast);
+ user->removeAndDeallocate();
+ castInstWorkList.add(newCast);
+ break;
+ }
+ case kIROp_FieldExtract:
+ case kIROp_GetElement:
+ {
+ if (!isUseBaseAddrOperand(use, user))
+ break;
+ // elementExtract(castStorageToLogicalDeref(addr), key)
+ // ==> load(gep(castStorageToLogical(addr), key))
+ builder.setInsertBefore(user);
+ auto castAddr = builder.emitCastStorageToLogical(
+ builder.getPtrType(castInst->getDataType()),
+ ptrVal,
+ castInst->getBufferType());
+ IRInst* gep = nullptr;
+ if (user->getOp() == kIROp_GetElement)
+ gep = builder.emitElementAddress(castAddr, user->getOperand(1));
+ else
+ gep = builder.emitFieldAddress(castAddr, user->getOperand(1));
+ auto load = builder.emitLoad(gep);
+ user->replaceUsesWith(load);
+ user->removeAndDeallocate();
+ if (auto castStorage = as<IRCastStorageToLogical>(castAddr))
+ castInstWorkList.add(castStorage);
+ break;
+ }
+ case kIROp_Store:
+ {
+ // If we see `store(tempVar, castStorageToLogicalDeref(addr))`,
+ // replace `tempVar` with `castStorageToLogical(addr)`.
+ if (castInst->getOp() != kIROp_CastStorageToLogicalDeref)
+ break;
+ auto store = as<IRStore>(user);
+ if (store->getVal() != castInst)
+ break;
+ auto dest = store->getPtr();
+ if (!dest->findDecorationImpl(
+ kIROp_TempCallArgImmutableVarDecoration))
+ break;
+ builder.setInsertBefore(user);
+ auto castAddr = builder.emitCastStorageToLogical(
+ builder.getPtrType(castInst->getDataType()),
+ ptrVal,
+ castInst->getBufferType());
+ dest->replaceUsesWith(castAddr);
+ dest->removeAndDeallocate();
+ if (auto castStorage = as<IRCastStorageToLogical>(castAddr))
+ castInstWorkList.add(castStorage);
+ break;
+ }
+ }
+ });
+ }
+
+ // Now that we have processed all GEP instructions, we can now proceed to
+ // process all calls. This is done by making a clone of the callee, and change
+ // the parameter type from logical type to storage type, and insert a
+ // castStorageToLogical on the parameter. Then we go back to the beginning and make sure
+ // we process those newly created castStorageToLogical insts.
+ List<IRCastStorageToLogicalBase*> newCasts;
+ List<IRCall*> callWorkList;
+ for (auto call : callWorkListSet)
+ callWorkList.add(call);
+ for (Index c = 0; c < callWorkList.getCount(); c++)
+ {
+ auto call = callWorkList[c];
+ auto calleeFunc = as<IRGlobalValueWithParams>(call->getCallee());
+ // We compute the func type for the specialized func based on the arguments
+ // provided, and check the specialization cache to reuse existing specialization
+ // when possible.
+ List<IRInst*> oldParams;
+ for (auto param : calleeFunc->getParams())
+ oldParams.add(param);
+ SLANG_ASSERT(oldParams.getCount() == (Index)call->getArgCount());
+
+ ShortList<IRType*> paramTypes;
+ ShortList<IRInst*> newArgs;
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ if (auto castArg = as<IRCastStorageToLogical>(arg))
+ {
+ auto oldParamPtrType = oldParams[i]->getDataType();
+ auto storageValueType = tryGetPointedToOrBufferElementType(
+ &builder,
+ castArg->getOperand(0)->getDataType());
+ auto storagePtrType =
+ getLoweredPtrLikeType(oldParamPtrType, storageValueType);
+ paramTypes.add(storagePtrType);
+ newArgs.add(castArg->getOperand(0));
+ }
+ else
+ {
+ paramTypes.add(arg->getDataType());
+ newArgs.add(arg);
+ }
+ }
+ auto specializedFuncType = builder.getFuncType(
+ (UInt)paramTypes.getCount(),
+ paramTypes.getArrayView().getBuffer(),
+ call->getDataType());
+ auto key = SpecializationKey{(IRFunc*)calleeFunc, specializedFuncType};
+ IRFunc* specializedFunc = nullptr;
+ if (!specializedFuncs.tryGetValue(key, specializedFunc))
+ {
+ specializedFunc = createSpecializedFuncThatUseStorageType(
+ call,
+ specializedFuncType,
+ newCasts);
+ specializedFuncs[key] = specializedFunc;
+
+ // The cloned function may also contain `call`s with
+ // `CastStorageToLogical` arguments, and we want to add
+ // those calls to the callWorkList for further processing.
+ discoverCallsToProcess(callWorkList, specializedFunc);
+ }
+ builder.setInsertBefore(call);
+ auto newCall = builder.emitCallInst(
+ call->getFullType(),
+ specializedFunc,
+ newArgs.getArrayView().arrayView);
+ call->replaceUsesWith(newCall);
+ call->removeAndDeallocate();
+ }
+
+ // Remove any casts that have no more uses.
+ for (auto cast : castInstWorkList)
+ {
+ if (!cast->hasUses())
+ cast->removeAndDeallocate();
+ }
+
+ // Continue to process new casts added during function specialization.
+ castInstWorkList.swapWith(newCasts);
+ }
+ }
+
+ IRFunc* createSpecializedFuncThatUseStorageType(
+ IRCall* call,
+ IRFuncType* specializedFuncType,
+ List<IRCastStorageToLogicalBase*>& outNewCasts)
+ {
+ IRBuilder builder(call);
+ builder.setInsertBefore(call->getCallee());
+
+ // Create a clone of the callee.
+ IRCloneEnv cloneEnv;
+ auto clonedFunc = as<IRFunc>(cloneInst(&cloneEnv, &builder, call->getCallee()));
+ List<IRUse*> uses;
+
+ // If a parameter is being translated to storage type,
+ // insert a cast to convert it to logical type.
+ List<IRParam*> params;
+ for (auto param : clonedFunc->getParams())
+ params.add(param);
+ for (UInt i = 0; i < (UInt)params.getCount(); i++)
+ {
+ auto param = params[i];
+ SLANG_RELEASE_ASSERT(i < call->getArgCount());
+ auto arg = call->getArg(i);
+ auto cast = as<IRCastStorageToLogical>(arg);
+ if (!cast)
+ continue;
+ auto logicalParamType = param->getFullType();
+ auto storageType = specializedFuncType->getParamType(i);
+ param->setFullType((IRType*)storageType);
+ setInsertAfterOrdinaryInst(&builder, param);
+
+ // Store uses of param before creating a cast inst that uses it.
+ uses.clear();
+ for (auto use = param->firstUse; use; use = use->nextUse)
+ uses.add(use);
+ auto castedParam =
+ builder.emitCastStorageToLogical(logicalParamType, param, cast->getBufferType());
+ if (auto castStorage = as<IRCastStorageToLogicalBase>(castedParam))
+ outNewCasts.add(castStorage);
+
+ // Replace all previous uses of param to use castedParam instead.
+ for (auto use : uses)
+ builder.replaceOperand(use, castedParam);
+ }
+ clonedFunc->setFullType(specializedFuncType);
+ removeLinkageDecorations(clonedFunc);
+ return clonedFunc;
+ }
+
void processModule(IRModule* module)
{
IRBuilder builder(module);
@@ -941,6 +1688,7 @@ struct LoweredElementTypeContext
{
IRType* bufferType;
IRType* elementType;
+ IRType* loweredBufferType = nullptr;
bool shouldWrapArrayInStruct = false;
};
List<BufferTypeInfo> bufferTypeInsts;
@@ -990,12 +1738,10 @@ struct LoweredElementTypeContext
bufferTypeInsts.add(BufferTypeInfo{(IRType*)globalInst, elementType});
}
- // Maintain a pending work list of all matrix addresses, and try to lower them out of
- // existance after everything else has been lowered.
- List<MatrixAddrWorkItem> matrixAddrInsts;
+ List<IRCastStorageToLogicalBase*> castInstWorkList;
- for (auto bufferTypeInfo : bufferTypeInsts)
+ for (auto& bufferTypeInfo : bufferTypeInsts)
{
auto bufferType = bufferTypeInfo.bufferType;
auto elementType = bufferTypeInfo.elementType;
@@ -1022,10 +1768,10 @@ struct LoweredElementTypeContext
(UInt)typeOperands.getCount(),
typeOperands.getArrayView().getBuffer());
- // We treat a value of a buffer type as a pointer, and use a work list to translate
- // all loads and stores through the pointer values that needs lowering.
+ // Replace all global buffer declarations to use the storage type instead,
+ // and insert initial `castStorageToLogical` instructions to convert the
+ // storage-typed pointer to logical-typed pointer.
- List<IRInst*> ptrValsWorkList;
traverseUses(
bufferType,
[&](IRUse* use)
@@ -1033,433 +1779,400 @@ struct LoweredElementTypeContext
auto user = use->getUser();
if (use != &user->typeUse)
return;
- ptrValsWorkList.add(use->getUser());
+ // We don't want to insert cast instructions for uses of
+ // intermediate address instructions that are themselves
+ // derived from some other base address. We will let
+ // the later part of the pass systematically propagate
+ // the cast through them.
+ switch (user->getOp())
+ {
+ case kIROp_FieldAddress:
+ case kIROp_GetElementPtr:
+ case kIROp_GetOffsetPtr:
+ case kIROp_RWStructuredBufferGetElementPtr:
+ return;
+ }
+ auto ptrVal = use->getUser();
+ setInsertAfterOrdinaryInst(&builder, ptrVal);
+ builder.replaceOperand(use, loweredBufferType);
+ auto logicalBufferType = getLoweredPtrLikeType(bufferType, elementType);
+ auto castStorageToLogical =
+ builder.emitCastStorageToLogical(logicalBufferType, ptrVal, bufferType);
+ traverseUses(
+ ptrVal,
+ [&](IRUse* ptrUse)
+ {
+ if (ptrUse->getUser() != castStorageToLogical)
+ builder.replaceOperand(ptrUse, castStorageToLogical);
+ });
+ if (auto castStorage = as<IRCastStorageToLogical>(castStorageToLogical))
+ castInstWorkList.add(castStorage);
});
+ bufferTypeInfo.loweredBufferType = loweredBufferType;
+ }
+
+ // Push down `CastStorageToLogical` insts we inserted above to the latest possible locations,
+ // specializing all function calls along the way, until we truly need the logical value.
+ // This means that `FieldAddr(CastStorageToLogical(buffer), field0)` is translated to
+ // `CastStorageToLogical(FieldAddr(buffer, field0))`. This way we can be sure that we are
+ // doing minimal packing/unpacking.
+ deferStorageToLogicalCasts(module, _Move(castInstWorkList));
+
+ // Now translate the `CastStorageToLogical` into actual packing/unpacking code.
+ materializeStorageToLogicalCasts(module->getModuleInst());
+
+ // Replace all remaining uses of bufferType with loweredBufferType. These uses are
+ // non-operational and should be directly replaceable, such as uses in `IRFuncType`.
+ for (auto bufferTypeInst : bufferTypeInsts)
+ {
+ if (!bufferTypeInst.loweredBufferType)
+ continue;
+ bufferTypeInst.bufferType->replaceUsesWith(bufferTypeInst.loweredBufferType);
+ bufferTypeInst.bufferType->removeAndDeallocate();
+ }
+ }
- // Translate the values to use new lowered buffer type instead.
- for (Index i = 0; i < ptrValsWorkList.getCount(); i++)
+ void materializeStorageToLogicalCastsImpl(IRCastStorageToLogicalBase* castInst)
+ {
+ IRBuilder builder(castInst);
+ if (!castInst->hasUses())
+ {
+ castInst->removeAndDeallocate();
+ return;
+ }
+ if (castInst->getOp() == kIROp_CastStorageToLogicalDeref)
+ {
+ // Convert CastStorageToLogicalDeref to load(CastStorageToLogical) to reuse
+ // the same materialization logic for CastStorageToLogical.
+ //
+ builder.setInsertBefore(castInst);
+ auto ptrType = builder.getPtrType(castInst->getDataType());
+ auto castPtr = builder.emitCastStorageToLogical(
+ (IRType*)ptrType,
+ castInst->getVal(),
+ castInst->getBufferType());
+ auto load = builder.emitLoad(castPtr);
+ castInst->replaceUsesWith(load);
+ castInst->removeAndDeallocate();
+ if (auto castStorage = as<IRCastStorageToLogical>(castPtr))
+ materializeStorageToLogicalCastsImpl(castStorage);
+ return;
+ }
+
+ // Translate the values to use new lowered buffer type instead.
+
+ auto ptrVal = castInst->getOperand(0);
+ auto oldPtrType = castInst->getFullType();
+ auto originalElementType = oldPtrType->getOperand(0);
+ auto config = getTypeLoweringConfigForBuffer(target, (IRType*)castInst->getBufferType());
+
+
+ LoweredElementTypeInfo loweredElementTypeInfo = {};
+ if (auto getElementPtr = as<IRGetElementPtr>(ptrVal))
+ {
+ if (auto arrayType = as<IRArrayTypeBase>(tryGetPointedToOrBufferElementType(
+ &builder,
+ getElementPtr->getBase()->getDataType())))
{
- auto ptrVal = ptrValsWorkList[i];
- auto oldPtrType = ptrVal->getFullType();
- auto originalElementType = oldPtrType->getOperand(0);
-
- // If we are accessing an unsized array element from a pointer, we need to compute
- // the trailing ptr that points to the first element of the array.
- // And then replace all getElementPtr(arrayPtr, index) with
- // getOffsetPtr(trailingPtr, index).
- if (auto fieldAddr = as<IRFieldAddress>(ptrVal))
+ // For WGSL, an array of scalar or vector type will always be converted to
+ // an array of 16-byte aligned vector type. In this case, we will run into a
+ // GetElementPtr where the result type is different from the element type of
+ // the base array.
+ // We should setup loweredElementTypeInfo so the remaining logic can handle
+ // this case and insert proper packing/unpacking logic around it.
+ if (arrayType->getElementType() != originalElementType &&
+ isScalarOrVectorType(originalElementType))
{
- auto handleUnsizedArrayAccess = [&]() -> bool
- {
- auto ptrType = as<IRPtrType>(ptrVal->getDataType());
- if (!ptrType)
- return false;
- if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
- return false;
- if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType()))
- {
- builder.setInsertBefore(ptrVal);
- auto newArrayPtrVal = fieldAddr->getBase();
- auto loweredInnerType =
- getLoweredTypeInfo(unsizedArrayType->getElementType(), config);
+ loweredElementTypeInfo.loweredType = arrayType->getElementType();
+ loweredElementTypeInfo.originalType = (IRType*)originalElementType;
+ loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod(
+ loweredElementTypeInfo.originalType,
+ loweredElementTypeInfo.loweredType);
+ loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod(
+ loweredElementTypeInfo.loweredType,
+ loweredElementTypeInfo.originalType);
+ }
+ }
+ }
- IRSizeAndAlignment arrayElementSizeAlignment;
- getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- loweredInnerType.loweredType,
- &arrayElementSizeAlignment);
- IRSizeAndAlignment baseSizeAlignment;
- getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
- &baseSizeAlignment);
+ // For general cases we simply check if the element type needs lowering.
+ // If so we will insert packing/unpacking logic if necessary.
+ //
+ if (!loweredElementTypeInfo.loweredType)
+ {
+ loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType, config);
+ }
- // Convert pointer to uint64 and adjust offset.
- IRIntegerValue offset = baseSizeAlignment.size;
- offset = align(offset, arrayElementSizeAlignment.alignment);
- if (offset != 0)
- {
- auto rawPtr =
- builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
- newArrayPtrVal = builder.emitAdd(
- rawPtr->getFullType(),
- rawPtr,
- builder.getIntValue(builder.getUInt64Type(), offset));
- }
- newArrayPtrVal = builder.emitBitCast(
- builder.getPtrType(
- loweredInnerType.loweredType,
- ptrType->getAddressSpace()),
- newArrayPtrVal);
- traverseUses(
- ptrVal,
- [&](IRUse* use)
- {
- auto user = use->getUser();
- if (user->getOp() == kIROp_GetElementPtr)
- {
- builder.setInsertBefore(user);
- auto newElementPtr = builder.emitGetOffsetPtr(
- newArrayPtrVal,
- user->getOperand(1));
- user->replaceUsesWith(newElementPtr);
- user->removeAndDeallocate();
- ptrValsWorkList.add(newElementPtr);
- }
- else if (user->getOp() == kIROp_GetOffsetPtr)
- {
- }
- else
- {
- SLANG_UNEXPECTED(
- "unknown use of pointer to unsized array.");
- }
- });
- SLANG_ASSERT(!ptrVal->hasUses());
- ptrVal->removeAndDeallocate();
- return true;
- }
- return false;
- };
- if (handleUnsizedArrayAccess())
- continue;
- }
+ if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType)
+ {
+ castInst->replaceUsesWith(ptrVal);
+ castInst->removeAndDeallocate();
+ return;
+ }
- LoweredElementTypeInfo loweredElementTypeInfo = {};
- if (auto getElementPtr = as<IRGetElementPtr>(ptrVal))
+ traverseUses(
+ castInst,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ if (as<IRDecoration>(user))
+ return;
+ switch (user->getOp())
{
- if (auto arrayType = as<IRArrayTypeBase>(
- tryGetPointedToType(&builder, getElementPtr->getBase()->getDataType())))
+ case kIROp_Load:
+ case kIROp_StructuredBufferLoad:
+ case kIROp_StructuredBufferLoadStatus:
+ case kIROp_RWStructuredBufferLoad:
+ case kIROp_RWStructuredBufferLoadStatus:
+ case kIROp_StructuredBufferConsume:
{
- // For WGSL, an array of scalar or vector type will always be converted to
- // an array of 16-byte aligned vector type. In this case, we will run into a
- // GetElementPtr where the result type is different from the element type of
- // the base array.
- // We should setup loweredElementTypeInfo so the remaining logic can handle
- // this case and insert proper packing/unpacking logic around it.
- if (arrayType->getElementType() != originalElementType &&
- isScalarOrVectorType(originalElementType))
+ if (castInst != user->getOperand(0))
+ break;
+ builder.setInsertBefore(user);
+ auto addr = getBufferAddr(builder, user, ptrVal);
+ if (!addr)
+ {
+ IRCloneEnv cloneEnv = {};
+ builder.setInsertBefore(user);
+ auto newLoad = cloneInst(&cloneEnv, &builder, user);
+ newLoad->setFullType(loweredElementTypeInfo.loweredType);
+ addr = builder.emitVar(loweredElementTypeInfo.loweredType);
+ builder.emitStore(addr, newLoad);
+ }
+ if (auto alignedAttr = user->findAttr<IRAlignedAttr>())
{
- loweredElementTypeInfo.loweredType = arrayType->getElementType();
- loweredElementTypeInfo.originalType = (IRType*)originalElementType;
- loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod(
- loweredElementTypeInfo.originalType,
- loweredElementTypeInfo.loweredType);
- loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod(
- loweredElementTypeInfo.loweredType,
- loweredElementTypeInfo.originalType);
+ builder.addAlignedAddressDecoration(addr, alignedAttr->getAlignment());
}
+ auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply(
+ builder,
+ loweredElementTypeInfo.originalType,
+ addr);
+ user->replaceUsesWith(unpackedVal);
+ user->removeAndDeallocate();
+ return;
}
- }
-
- // For general cases we simply check if the element type needs lowering.
- // If so we will insert packing/unpacking logic if necessary.
- //
- if (!loweredElementTypeInfo.loweredType)
- {
- loweredElementTypeInfo =
- getLoweredTypeInfo((IRType*)originalElementType, config);
- }
-
- if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType)
- continue;
-
- ptrVal->setFullType(getLoweredPtrLikeType(
- ptrVal->getFullType(),
- loweredElementTypeInfo.loweredType));
-
- traverseUses(
- ptrVal,
- [&](IRUse* use)
+ case kIROp_Store:
+ case kIROp_RWStructuredBufferStore:
+ case kIROp_StructuredBufferAppend:
{
- auto user = use->getUser();
- if (as<IRDecoration>(user))
- return;
- switch (user->getOp())
+ // Use must be the dest operand of the store inst.
+ if (use != user->getOperands() + 0)
+ break;
+ IRCloneEnv cloneEnv = {};
+ builder.setInsertBefore(user);
+ auto originalVal = getStoreVal(user);
+ if (auto sbAppend = as<IRStructuredBufferAppend>(user))
{
- case kIROp_Load:
- case kIROp_StructuredBufferLoad:
- case kIROp_StructuredBufferLoadStatus:
- case kIROp_RWStructuredBufferLoad:
- case kIROp_RWStructuredBufferLoadStatus:
- case kIROp_StructuredBufferConsume:
+ builder.setInsertBefore(sbAppend);
+ IRInst* addr = nullptr;
+ if (originalVal->getOp() == kIROp_CastStorageToLogicalDeref)
{
- builder.setInsertBefore(user);
- auto addr = getBufferAddr(builder, user);
- if (!addr)
- {
- IRCloneEnv cloneEnv = {};
- builder.setInsertBefore(user);
- auto newLoad = cloneInst(&cloneEnv, &builder, user);
- newLoad->setFullType(loweredElementTypeInfo.loweredType);
- addr = builder.emitVar(loweredElementTypeInfo.loweredType);
- builder.emitStore(addr, newLoad);
- }
- if (auto alignedAttr = user->findAttr<IRAlignedAttr>())
- {
- builder.addAlignedAddressDecoration(
- addr,
- alignedAttr->getAlignment());
- }
- auto unpackedVal =
- loweredElementTypeInfo.convertLoweredToOriginal.apply(
- builder,
- loweredElementTypeInfo.originalType,
- addr);
- user->replaceUsesWith(unpackedVal);
- user->removeAndDeallocate();
- break;
+ addr = originalVal->getOperand(0);
}
- case kIROp_Store:
- case kIROp_RWStructuredBufferStore:
- case kIROp_StructuredBufferAppend:
+ else
{
- // Use must be the dest operand of the store inst.
- if (use != user->getOperands() + 0)
- break;
- IRCloneEnv cloneEnv = {};
- builder.setInsertBefore(user);
- auto originalVal = getStoreVal(user);
- IRInst* addr = getBufferAddr(builder, user);
- if (addr)
- {
- if (auto alignedAttr = user->findAttr<IRAlignedAttr>())
- {
- builder.addAlignedAddressDecoration(
- addr,
- alignedAttr->getAlignment());
- }
-
- loweredElementTypeInfo.convertOriginalToLowered
- .applyDestinationDriven(builder, addr, originalVal);
- user->removeAndDeallocate();
- }
- else if (auto sbAppend = as<IRStructuredBufferAppend>(user))
- {
- builder.setInsertBefore(sbAppend);
- addr = builder.emitVar(loweredElementTypeInfo.loweredType);
- loweredElementTypeInfo.convertOriginalToLowered
- .applyDestinationDriven(builder, addr, originalVal);
- auto packedVal = builder.emitLoad(addr);
- sbAppend->setOperand(1, packedVal);
- }
- else
- {
- SLANG_UNREACHABLE("unhandled store type");
- }
- break;
+ addr = builder.emitVar(loweredElementTypeInfo.loweredType);
+ loweredElementTypeInfo.convertOriginalToLowered
+ .applyDestinationDriven(builder, addr, originalVal);
}
- case kIROp_GetElementPtr:
- case kIROp_FieldAddress:
+ auto packedVal = builder.emitLoad(addr);
+ sbAppend->setOperand(1, packedVal);
+ }
+ else
+ {
+ IRInst* addr = getBufferAddr(builder, user, ptrVal);
+ if (auto alignedAttr = user->findAttr<IRAlignedAttr>())
{
- // If original type is an array, the lowered type will be a struct.
- // In that case, all existing address insts should be appended with
- // a field extract.
- if (as<IRArrayType>(originalElementType))
- {
- builder.setInsertBefore(user);
- List<IRInst*> args;
- for (UInt i = 0; i < user->getOperandCount(); i++)
- args.add(user->getOperand(i));
- auto newArrayPtrVal = builder.emitFieldAddress(
- builder.getPtrType(
- loweredElementTypeInfo.loweredInnerArrayType),
- ptrVal,
- loweredElementTypeInfo.loweredInnerStructKey);
- builder.replaceOperand(use, newArrayPtrVal);
- ptrValsWorkList.add(user);
- }
- else if (as<IRMatrixType>(originalElementType))
- {
- // We are tring to get a pointer to a lowered matrix element.
- // We process this insts at a later phase.
- SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr);
- matrixAddrInsts.add(MatrixAddrWorkItem{user, config});
- }
- else
- {
- // If we getting a derived address from the pointer, we need
- // to recursively lower the new address. We do so by pushing
- // the address inst into the work list.
- ptrValsWorkList.add(user);
- }
+ builder.addAlignedAddressDecoration(
+ addr,
+ alignedAttr->getAlignment());
}
- break;
- case kIROp_RWStructuredBufferGetElementPtr:
- case kIROp_GetOffsetPtr:
- ptrValsWorkList.add(user);
- break;
- case kIROp_StructuredBufferGetDimensions:
- break;
- case kIROp_Call:
+ if (originalVal->getOp() == kIROp_CastStorageToLogicalDeref)
+ {
+ auto valAddr = originalVal->getOperand(0);
+ auto storageVal = builder.emitLoad(valAddr);
+ builder.emitStore(addr, storageVal);
+ }
+ else
{
- // If a structured buffer or pointer typed value is used directly as
- // an argument, we don't need to do any marshalling here.
- if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType()))
- break;
- if (options.lowerBufferPointer &&
- as<IRPtrType>(ptrVal->getDataType()))
- break;
- // If we are calling a function with an l-value pointer from buffer
- // access, we need to materialize the object as a local variable,
- // and pass the address of the local variable to the function.
- builder.setInsertBefore(user);
- auto unpackedVal =
- loweredElementTypeInfo.convertLoweredToOriginal.apply(
- builder,
- (IRType*)originalElementType,
- ptrVal);
- auto var = builder.emitVar((IRType*)originalElementType);
- builder.emitStore(var, unpackedVal);
- use->set(var);
- builder.setInsertAfter(user);
- auto newVal = builder.emitLoad(var);
loweredElementTypeInfo.convertOriginalToLowered
- .applyDestinationDriven(builder, ptrVal, newVal);
+ .applyDestinationDriven(builder, addr, originalVal);
}
- break;
- default:
- break;
+ user->removeAndDeallocate();
}
- });
- }
+ return;
+ }
+ default:
+ break;
+ }
+ // If the pointer is used in any other way that we don't recognize,
+ // preserve it as is without translation.
+ builder.setInsertBefore(user);
+ builder.replaceOperand(use, ptrVal);
+ });
+
+ if (!castInst->hasUses())
+ castInst->removeAndDeallocate();
+ }
- // Replace all remaining uses of bufferType to loweredBufferType, these uses are
- // non-operational and should be directly replaceable, such as uses in `IRFuncType`.
- bufferType->replaceUsesWith(loweredBufferType);
- bufferType->removeAndDeallocate();
+ void collectInstsOfType(List<IRCastStorageToLogicalBase*>& insts, IRInst* root, IROp op)
+ {
+ if (root->getOp() == op)
+ {
+ insts.add((IRCastStorageToLogicalBase*)root);
+ return;
+ }
+ for (auto child : root->getChildren())
+ {
+ collectInstsOfType(insts, child, op);
}
+ }
- // Process all matrix address uses.
- lowerMatrixAddresses(module, matrixAddrInsts);
+ void materializeStorageToLogicalCasts(IRInst* root)
+ {
+ // We will process all CastStorageToLogical insts first, before
+ // processing all CastStorageToLogicalDeref.
+ // This is because when we materialize a
+ // `store(CastStorageToLogical(addr), CastStorageToLogicalDeref(src))`,
+ // we can just fold out CastStorageToLogicalDeref and emit
+ // `store(addr, load(src))` instead.
+ // If we materialized `CastStorageToLogicalDeref` first we will
+ // miss this opportunity and generate more bloated code.
+ //
+ List<IRCastStorageToLogicalBase*> castInsts;
+ collectInstsOfType(castInsts, root, kIROp_CastStorageToLogical);
+ for (auto inst : castInsts)
+ materializeStorageToLogicalCastsImpl(inst);
+
+ castInsts.clear();
+ collectInstsOfType(castInsts, root, kIROp_CastStorageToLogicalDeref);
+ for (auto inst : castInsts)
+ materializeStorageToLogicalCastsImpl(inst);
}
// Lower all getElementPtr insts of a lowered matrix out of existence.
- void lowerMatrixAddresses(IRModule* module, List<MatrixAddrWorkItem>& matrixAddrInsts)
+ void lowerMatrixAddresses(IRModule* module, MatrixAddrWorkItem workItem)
{
IRBuilder builder(module);
- for (auto workItem : matrixAddrInsts)
- {
- auto majorAddr = workItem.matrixAddrInst;
- auto majorGEP = as<IRGetElementPtr>(majorAddr);
- SLANG_ASSERT(majorGEP);
- auto loweredMatrixType =
- cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType();
- auto matrixTypeInfo = getTypeLoweringMap(workItem.config)
- .mapLoweredTypeToInfo.tryGetValue(loweredMatrixType);
- SLANG_ASSERT(matrixTypeInfo);
- auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType);
- auto rowCount = getIntVal(matrixType->getRowCount());
- traverseUses(
- majorAddr,
- [&](IRUse* use)
+ auto majorAddr = workItem.matrixAddrInst;
+ auto majorGEP = as<IRGetElementPtr>(majorAddr);
+ SLANG_ASSERT(majorGEP);
+ auto baseCast = as<IRCastStorageToLogical>(majorGEP->getBase());
+ SLANG_ASSERT(baseCast);
+ auto storageBase = baseCast->getOperand(0);
+ auto loweredMatrixType = cast<IRPtrTypeBase>(storageBase->getFullType())->getValueType();
+ auto matrixTypeInfo =
+ getTypeLoweringMap(workItem.config).mapLoweredTypeToInfo.tryGetValue(loweredMatrixType);
+ SLANG_ASSERT(matrixTypeInfo);
+ if (matrixTypeInfo->loweredType == matrixTypeInfo->originalType)
+ return;
+ auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType);
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ traverseUses(
+ majorAddr,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ builder.setInsertBefore(user);
+ switch (user->getOp())
{
- auto user = use->getUser();
- builder.setInsertBefore(user);
- switch (user->getOp())
+ case kIROp_Load:
{
- case kIROp_Load:
+ IRInst* resultInst = nullptr;
+ auto dataPtr = builder.emitFieldAddress(
+ getLoweredPtrLikeType(
+ majorAddr->getDataType(),
+ matrixTypeInfo->loweredInnerArrayType),
+ storageBase,
+ matrixTypeInfo->loweredInnerStructKey);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
{
- IRInst* resultInst = nullptr;
- auto dataPtr = builder.emitFieldAddress(
- getLoweredPtrLikeType(
- majorAddr->getDataType(),
- matrixTypeInfo->loweredInnerArrayType),
- majorGEP->getBase(),
- matrixTypeInfo->loweredInnerStructKey);
- if (getIntVal(matrixType->getLayout()) ==
- SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- List<IRInst*> args;
- for (IRIntegerValue i = 0; i < rowCount; i++)
- {
- auto vector =
- builder.emitLoad(builder.emitElementAddress(dataPtr, i));
- auto element =
- builder.emitElementExtract(vector, majorGEP->getIndex());
- args.add(element);
- }
- resultInst = builder.emitMakeVector(
- builder.getVectorType(
- matrixType->getElementType(),
- (IRIntegerValue)args.getCount()),
- args);
- }
- else
+ List<IRInst*> args;
+ for (IRIntegerValue i = 0; i < colCount; i++)
{
+ auto vector =
+ builder.emitLoad(builder.emitElementAddress(dataPtr, i));
auto element =
- builder.emitElementAddress(dataPtr, majorGEP->getIndex());
- resultInst = builder.emitLoad(element);
+ builder.emitElementExtract(vector, majorGEP->getIndex());
+ args.add(element);
}
- user->replaceUsesWith(resultInst);
- user->removeAndDeallocate();
+ resultInst = builder.emitMakeVector(
+ builder.getVectorType(
+ matrixType->getElementType(),
+ (IRIntegerValue)args.getCount()),
+ args);
}
- break;
- case kIROp_Store:
+ else
{
- auto storeInst = cast<IRStore>(user);
- if (storeInst->getOperand(0) != majorAddr)
- break;
- auto dataPtr = builder.emitFieldAddress(
- getLoweredPtrLikeType(
- majorAddr->getDataType(),
- matrixTypeInfo->loweredInnerArrayType),
- majorGEP->getBase(),
- matrixTypeInfo->loweredInnerStructKey);
- if (getIntVal(matrixType->getLayout()) ==
- SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue i = 0; i < rowCount; i++)
- {
- auto vectorAddr = builder.emitElementAddress(dataPtr, i);
- auto elementAddr = builder.emitElementAddress(
- vectorAddr,
- majorGEP->getIndex());
- builder.emitStore(
- elementAddr,
- builder.emitElementExtract(storeInst->getVal(), i));
- }
- }
- else
- {
- auto rowAddr =
- builder.emitElementAddress(dataPtr, majorGEP->getIndex());
- builder.emitStore(rowAddr, storeInst->getVal());
- user->removeAndDeallocate();
- }
- break;
+ auto element =
+ builder.emitElementAddress(dataPtr, majorGEP->getIndex());
+ resultInst = builder.emitLoad(element);
}
- case kIROp_GetElementPtr:
+ user->replaceUsesWith(resultInst);
+ user->removeAndDeallocate();
+ }
+ break;
+ case kIROp_Store:
+ {
+ auto storeInst = cast<IRStore>(user);
+ if (storeInst->getOperand(0) != majorAddr)
+ break;
+ auto dataPtr = builder.emitFieldAddress(
+ getLoweredPtrLikeType(
+ majorAddr->getDataType(),
+ matrixTypeInfo->loweredInnerArrayType),
+ storageBase,
+ matrixTypeInfo->loweredInnerStructKey);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
{
- auto gep2 = cast<IRGetElementPtr>(user);
- auto rowIndex = majorGEP->getIndex();
- auto colIndex = gep2->getIndex();
- if (getIntVal(matrixType->getLayout()) ==
- SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ for (IRIntegerValue i = 0; i < colCount; i++)
{
- Swap(rowIndex, colIndex);
+ auto vectorAddr = builder.emitElementAddress(dataPtr, i);
+ auto elementAddr =
+ builder.emitElementAddress(vectorAddr, majorGEP->getIndex());
+ builder.emitStore(
+ elementAddr,
+ builder.emitElementExtract(storeInst->getVal(), i));
}
- auto dataPtr = builder.emitFieldAddress(
- getLoweredPtrLikeType(
- majorAddr->getDataType(),
- matrixTypeInfo->loweredInnerArrayType),
- majorGEP->getBase(),
- matrixTypeInfo->loweredInnerStructKey);
- auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex);
- auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex);
- gep2->replaceUsesWith(elementAddr);
- gep2->removeAndDeallocate();
- break;
}
- default:
- SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs "
- "storage lowering.");
+ else
+ {
+ auto rowAddr =
+ builder.emitElementAddress(dataPtr, majorGEP->getIndex());
+ builder.emitStore(rowAddr, storeInst->getVal());
+ user->removeAndDeallocate();
+ }
break;
}
- });
- }
+ case kIROp_GetElementPtr:
+ {
+ auto gep2 = cast<IRGetElementPtr>(user);
+ auto rowIndex = majorGEP->getIndex();
+ auto colIndex = gep2->getIndex();
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ Swap(rowIndex, colIndex);
+ }
+ auto dataPtr = builder.emitFieldAddress(
+ getLoweredPtrLikeType(
+ majorAddr->getDataType(),
+ matrixTypeInfo->loweredInnerArrayType),
+ storageBase,
+ matrixTypeInfo->loweredInnerStructKey);
+ auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex);
+ auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex);
+ gep2->replaceUsesWith(elementAddr);
+ gep2->removeAndDeallocate();
+ break;
+ }
+ default:
+ SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs "
+ "storage lowering.");
+ break;
+ }
+ });
+ if (!majorAddr->hasUses())
+ majorAddr->removeAndDeallocate();
}
};
diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h
index 2c69c5476..9d6e53609 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.h
+++ b/source/slang/slang-ir-lower-buffer-element-type.h
@@ -10,7 +10,7 @@ struct IRType;
struct BufferElementTypeLoweringOptions
{
- bool lowerBufferPointer = false;
+ bool lowerBufferPointer = true;
// For WGSL, we can only create arrays that have a stride of 16 bytes for constant buffers.
bool use16ByteArrayElementForConstantBuffer = false;
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index fd950b91a..e66617e72 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -135,6 +135,16 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner
case kIROp_RWStructuredBufferGetElementPtr:
outAddressSpace = AddressSpace::Global;
return true;
+ case kIROp_Load:
+ {
+ auto addrSpace = getAddressSpaceFromVarType(inst->getDataType());
+ if (addrSpace != AddressSpace::Generic)
+ {
+ outAddressSpace = addrSpace;
+ return true;
+ }
+ }
+ return false;
default:
return false;
}
@@ -256,10 +266,13 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
legalizeEntryPointVaryingParamsForMetal(module, sink, entryPoints);
+ processInst(module->getModuleInst(), sink);
+}
+
+void specializeAddressSpaceForMetal(IRModule* module)
+{
MetalAddressSpaceAssigner metalAddressSpaceAssigner;
specializeAddressSpace(module, &metalAddressSpaceAssigner);
-
- processInst(module->getModuleInst(), sink);
}
} // namespace Slang
diff --git a/source/slang/slang-ir-metal-legalize.h b/source/slang/slang-ir-metal-legalize.h
index e49c1b9e0..98c19de60 100644
--- a/source/slang/slang-ir-metal-legalize.h
+++ b/source/slang/slang-ir-metal-legalize.h
@@ -7,4 +7,6 @@ namespace Slang
class DiagnosticSink;
void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink);
+void specializeAddressSpaceForMetal(IRModule* module);
+
} // namespace Slang
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 94bb6b67c..4c10cf246 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -196,6 +196,9 @@ static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func)
continue;
}
+ if (destPtr->findDecorationImpl(kIROp_DisableCopyEliminationDecoration))
+ continue;
+
// Check if we're storing a load result
auto loadInst = as<IRLoad>(storedValue);
if (!loadInst)
diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp
index 29f1ec516..c4a155eec 100644
--- a/source/slang/slang-ir-specialize-address-space.cpp
+++ b/source/slang/slang-ir-specialize-address-space.cpp
@@ -168,6 +168,7 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext
{
case kIROp_Var:
case kIROp_RWStructuredBufferGetElementPtr:
+ case kIROp_Load:
{
// The address space of these insts should be assigned by the initial
// address space assigner.
@@ -204,16 +205,6 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext
}
}
break;
- case kIROp_Load:
- {
- if (auto addrSpace =
- mapVarValueToAddrSpace.tryGetValue(inst->getOperand(0)))
- {
- mapInstToAddrSpace[inst] = *addrSpace;
- changed = true;
- }
- }
- break;
case kIROp_Param:
if (!isFirstBlock)
{
@@ -248,22 +239,24 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext
if (callee)
{
List<AddressSpace> argAddrSpaces;
- bool fullySpecialized = true;
+ bool hasSpecializableArg = false;
for (UInt i = 0; i < callInst->getArgCount(); i++)
{
auto arg = callInst->getArg(i);
- auto argAddrSpace = getAddrSpace(arg);
argAddrSpaces.add(getAddrSpace(arg));
- if (argAddrSpace == AddressSpace::Generic &&
- as<IRPtrTypeBase>(arg->getDataType()))
+ if (as<IRPtrTypeBase>(arg->getDataType()))
{
- fullySpecialized = false;
- break;
+ hasSpecializableArg = true;
}
}
- if (!fullySpecialized)
+ if (!hasSpecializableArg)
+ {
+ workList.add(callee);
+ break;
+ }
+ // If callee doesn't have a body, don't specialize.
+ if (!callee->getFirstBlock())
break;
-
FuncSpecializationKey key(callee, argAddrSpaces);
IRFunc* specializedCallee = nullptr;
if (IRFunc** specializedFunc =
@@ -484,4 +477,14 @@ void propagateAddressSpaceFromInsts(List<IRInst*>&& workList)
}
}
+AddressSpace NoOpInitialAddressSpaceAssigner::getAddressSpaceFromVarType(IRInst*)
+{
+ return AddressSpace::Generic;
+}
+
+AddressSpace NoOpInitialAddressSpaceAssigner::getLeafInstAddressSpace(IRInst*)
+{
+ return AddressSpace::Generic;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h
index 7e5f0fd9b..89145cf87 100644
--- a/source/slang/slang-ir-specialize-address-space.h
+++ b/source/slang/slang-ir-specialize-address-space.h
@@ -24,6 +24,13 @@ struct InitialAddressSpaceAssigner
virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) = 0;
};
+struct NoOpInitialAddressSpaceAssigner : public InitialAddressSpaceAssigner
+{
+ virtual bool tryAssignAddressSpace(IRInst*, AddressSpace&) { return false; }
+ virtual AddressSpace getAddressSpaceFromVarType(IRInst* type);
+ virtual AddressSpace getLeafInstAddressSpace(IRInst* inst);
+};
+
/// Propagate address space information through the IR module.
/// Specialize functions with reference/pointer parameters to use the correct address space
/// based on the address space of the arguments.
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 2c4bd11cc..f795a6559 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -893,18 +893,22 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto funcType = as<IRFuncType>(funcValue->getFullType());
+ bool argsChanged = false;
for (UInt i = 0; i < inst->getArgCount(); i++)
{
auto arg = inst->getArg(i);
auto paramType = funcType->getParamType(i);
- if (as<IRPtrType>(paramType))
+ if (auto ptrType = as<IRPtrType>(paramType))
{
- // If the parameter has an explicit pointer type,
- // then we know the user is using the variable pointer
- // capability to pass a true pointer.
- // In this case we should not rewrite the call.
- newArgs.add(arg);
- continue;
+ if (ptrType->getAddressSpace() == AddressSpace::UserPointer)
+ {
+ // If the parameter has an explicit pointer type,
+ // then we know the user is using the variable pointer
+ // capability to pass a true pointer.
+ // In this case we should not rewrite the call.
+ newArgs.add(arg);
+ continue;
+ }
}
auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
if (!as<IRPtrTypeBase>(arg->getDataType()))
@@ -953,13 +957,26 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// If we reach here, we need to allocate a temp var.
auto tempVar = builder.emitVar(ptrType->getValueType());
+ builder.addDecoration(tempVar, kIROp_DisableCopyEliminationDecoration);
auto load = builder.emitLoad(arg);
builder.emitStore(tempVar, load);
newArgs.add(tempVar);
+ argsChanged = true;
+
+ // We may need to write the value back to the original pointer argument
+ // after the call.
+ //
+ // If callee doesn't modify the memory location, no need to write back.
+ if (funcType && funcType->getParamCount() > i &&
+ as<IRConstRefType>(funcType->getParamType(i)))
+ continue;
+ // If the buffer location is immutable, don't write back.
+ if (isPointerToImmutableLocation(root))
+ continue;
writeBacks.add(WriteBackPair{arg, tempVar});
}
SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount());
- if (writeBacks.getCount())
+ if (argsChanged)
{
auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs);
for (auto wb : writeBacks)
@@ -2297,19 +2314,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// so we need to update the function types to match that.
updateFunctionTypes();
- // Lower all loads/stores from buffer pointers to use correct storage types.
- // We didn't do the lowering for buffer pointers because we don't know which pointer
- // types are actual storage buffer pointers until we propagated the address space of
- // pointers in this pass. In the future we should consider separate out IRAddress as
- // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can
- // safely lower the pointer load stores early together with other buffer types.
- BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions;
- bufferElementTypeLoweringOptions.lowerBufferPointer = true;
- lowerBufferElementTypeToStorageType(
- m_sharedContext->m_targetProgram,
- m_module,
- bufferElementTypeLoweringOptions);
-
// Look for structs that are both used as fields and marked with Block
// decorations, and move the Block decoration to a wrapper struct.
legalizeStructBlocks();
diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp
index 9328a1de1..d34b3d25b 100644
--- a/source/slang/slang-ir-transform-params-to-constref.cpp
+++ b/source/slang/slang-ir-transform-params-to-constref.cpp
@@ -31,10 +31,7 @@ struct TransformParamsToConstRefContext
case kIROp_StructType:
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
- case kIROp_VectorType:
- case kIROp_MatrixType:
case kIROp_TupleType:
- case kIROp_CoopVectorType:
// valid type, continue to check
break;
default:
@@ -44,58 +41,121 @@ struct TransformParamsToConstRefContext
return true;
}
- void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ void rewriteValueUsesToAddrUses(IRInst* newAddrInst)
{
- // Traverse the uses of our updated params to rewrite them.
- // Assume a `in` parameter has been converted to a `constref` parameter.
- for (auto param : updatedParams)
+ HashSet<IRInst*> workListSet;
+ workListSet.add(newAddrInst);
+ List<IRInst*> workList;
+ workList.add(newAddrInst);
+ auto _addToWorkList = [&](IRInst* inst)
+ {
+ if (workListSet.add(inst))
+ workList.add(inst);
+ };
+ for (Index i = 0; i < workList.getCount(); i++)
{
+ auto inst = workList[i];
traverseUses(
- param,
+ inst,
[&](IRUse* use)
{
auto user = use->getUser();
+ if (workListSet.contains(user))
+ return;
switch (user->getOp())
{
case kIROp_FieldExtract:
{
// Transform the IRFieldExtract into a IRFieldAddress
+ if (!isUseBaseAddrOperand(use, user))
+ break;
auto fieldExtract = as<IRFieldExtract>(user);
builder.setInsertBefore(fieldExtract);
auto fieldAddr = builder.emitFieldAddress(
fieldExtract->getBase(),
fieldExtract->getField());
- auto loadInst = builder.emitLoad(fieldAddr);
- fieldExtract->replaceUsesWith(loadInst);
- fieldExtract->removeAndDeallocate();
- break;
+ fieldExtract->replaceUsesWith(fieldAddr);
+ _addToWorkList(fieldAddr);
+ return;
}
case kIROp_GetElement:
{
// Transform the IRGetElement into a IRGetElementPtr
+ if (!isUseBaseAddrOperand(use, user))
+ break;
auto getElement = as<IRGetElement>(user);
builder.setInsertBefore(getElement);
auto elemAddr = builder.emitElementAddress(
getElement->getBase(),
getElement->getIndex());
- auto loadInst = builder.emitLoad(elemAddr);
- getElement->replaceUsesWith(loadInst);
- getElement->removeAndDeallocate();
- break;
+ getElement->replaceUsesWith(elemAddr);
+ _addToWorkList(elemAddr);
+ return;
}
- default:
+ case kIROp_Store:
{
- // Insert a load before the user and replace the user with the load
- builder.setInsertBefore(user);
- auto loadInst = builder.emitLoad(param);
- use->set(loadInst);
+ // If the current value is being stored into a write-once temp var that
+ // is immediately passed into a constref location in a call, we can get
+ // rid of the temp var and replace it with `inst` directly.
+ // (such temp var can be introduced during `updateCallSites` when we
+ // were processing the callee.)
+ //
+ auto dest = as<IRStore>(user)->getPtr();
+ if (dest->findDecorationImpl(kIROp_TempCallArgImmutableVarDecoration))
+ {
+ user->removeAndDeallocate();
+ dest->replaceUsesWith(inst);
+ dest->removeAndDeallocate();
+ return;
+ }
break;
}
}
+ // Insert a load before the user and replace the user with the load
+ builder.setInsertBefore(user);
+ auto loadInst = builder.emitLoad(inst);
+ use->set(loadInst);
});
}
}
+ void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ {
+ // Traverse the uses of our updated params to rewrite them.
+ // Assume an `in` parameter has been converted to a `constref` parameter.
+ for (auto param : updatedParams)
+ {
+ rewriteValueUsesToAddrUses(param);
+ }
+ }
+
+ // If `load` is an `IRLoad(addr)` where `addr` is rooted in an immutable location, returns `addr`; otherwise returns nullptr.
+ IRInst* isLoadFromImmutableAddress(IRInst* load)
+ {
+ if (load->getOp() != kIROp_Load)
+ return nullptr;
+ auto addr = load->getOperand(0);
+ auto root = getRootAddr(addr);
+ if (!root)
+ return nullptr;
+ if (!root->getDataType())
+ return nullptr;
+ switch (root->getDataType()->getOp())
+ {
+ case kIROp_ConstantBufferType:
+ case kIROp_ConstRefType:
+ case kIROp_ParameterBlockType:
+ return addr;
+ default:
+ // Note that we should in general not treat a read-only StructuredBuffer or
+ // a pointer with read-only access as an immutable location due to potential aliasing.
+ // We could introduce a compiler flag to turn on optimizations on these buffer types
+ // assuming there is no aliasing.
+ break;
+ }
+ return nullptr;
+ }
+
// Update call sites to pass an address instead of value for each updated-param
void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams)
{
@@ -119,10 +179,19 @@ struct TransformParamsToConstRefContext
newArgs.add(arg);
continue;
}
-
- auto tempVar = builder.emitVar(arg->getFullType());
- builder.emitStore(tempVar, arg);
- newArgs.add(tempVar);
+ if (auto addr = isLoadFromImmutableAddress(arg))
+ {
+ // If existing argument is a load from an immutable buffer address,
+ // we can pass in the address as is, without making a temporary copy.
+ newArgs.add(addr);
+ }
+ else
+ {
+ auto tempVar = builder.emitVar(arg->getFullType());
+ builder.addDecoration(tempVar, kIROp_TempCallArgImmutableVarDecoration);
+ builder.emitStore(tempVar, arg);
+ newArgs.add(tempVar);
+ }
}
// Create new call with updated arguments
@@ -177,6 +246,25 @@ struct TransformParamsToConstRefContext
{
HashSet<IRParam*> updatedParams;
+ // If the function is used in any way that is not understood by the
+ // compiler, do not modify it.
+ // For example, if the function is used as a callback, we must preserve
+ // its signature.
+ for (auto use = func->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (as<IRDecoration>(user))
+ continue;
+ if (auto call = as<IRCall>(user))
+ {
+ if (call->getCalleeUse() == use)
+ continue;
+ }
+ // If we reach here, we have encountered a non-call use of the func,
+ // so we stop processing and leave the function unmodified.
+ return;
+ }
+
// First pass: Transform parameter types
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
{
@@ -192,7 +280,7 @@ struct TransformParamsToConstRefContext
// This allows us to pass the address of variables directly into a function,
// giving us the choice to remove copies into a parameter.
auto paramType = param->getDataType();
- auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
+ auto constRefType = builder.getConstRefType(paramType, AddressSpace::Generic);
param->setFullType(constRefType);
changed = true;
@@ -205,6 +293,8 @@ struct TransformParamsToConstRefContext
return;
}
+ fixUpFuncType(func);
+
// Second pass: Update function body according to the new `constref` parameters
rewriteParamUseSitesToSupportConstRefUsage(updatedParams);
diff --git a/source/slang/slang-ir-undo-param-copy.cpp b/source/slang/slang-ir-undo-param-copy.cpp
index d8aac7201..75ee8a03a 100644
--- a/source/slang/slang-ir-undo-param-copy.cpp
+++ b/source/slang/slang-ir-undo-param-copy.cpp
@@ -6,7 +6,7 @@
namespace Slang
{
-// This pass transforms variables decorated with TempCallArgVarDecoration
+// This pass transforms variables decorated with TempCallArgVarDecoration or TempCallArgImmutableVarDecoration
// by replacing them with direct references to the original parameters.
// This is important for CUDA/OptiX targets where functions like 'IgnoreHit'
// can prevent copy-back operations from executing.
@@ -52,7 +52,17 @@ struct UndoParameterCopyVisitor
{
if (auto varInst = as<IRVar>(inst))
{
- if (varInst->findDecoration<IRTempCallArgVarDecoration>())
+ bool isTempCallArgVar = false;
+ for (auto decor : varInst->getDecorations())
+ {
+ if (as<IRTempCallArgImmutableVarDecoration>(decor) ||
+ as<IRTempCallArgVarDecoration>(decor))
+ {
+ isTempCallArgVar = true;
+ break;
+ }
+ }
+ if (isTempCallArgVar)
{
IRStore* initializingStore = nullptr;
IRInst* originalParamPtr = nullptr;
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 9b852b803..8584ea95e 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2128,6 +2128,40 @@ IRType* getIRVectorBaseType(IRType* type)
return as<IRVectorType>(type)->getElementType();
}
+IRType* getElementType(IRBuilder& builder, IRType* valueType)
+{
+ valueType = (IRType*)unwrapAttributedType(valueType);
+ if (auto arrayType = as<IRArrayTypeBase>(valueType))
+ {
+ return arrayType->getElementType();
+ }
+ else if (auto vectorType = as<IRVectorType>(valueType))
+ {
+ return vectorType->getElementType();
+ }
+ else if (auto basicType = as<IRBasicType>(valueType))
+ {
+ return basicType;
+ }
+ else if (auto coopVecType = as<IRCoopVectorType>(valueType))
+ {
+ return coopVecType->getElementType();
+ }
+ else if (auto matrixType = as<IRMatrixType>(valueType))
+ {
+ return builder.getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
+ }
+ else if (auto coopMatType = as<IRCoopMatrixType>(valueType))
+ {
+ return coopMatType->getElementType();
+ }
+ else if (auto hlslInputPatchType = as<IRHLSLInputPatchType>(valueType))
+ {
+ return hlslInputPatchType->getElementType();
+ }
+ return nullptr;
+}
+
Int getSpecializationConstantId(IRGlobalParam* param)
{
auto layout = findVarLayout(param);
@@ -2483,4 +2517,47 @@ bool isIROpaqueType(IRType* type)
}
}
+bool isPointerToImmutableLocation(IRInst* loc) // True when `loc` is recognized as referring to immutable memory by the checks below; false otherwise.
+{
+ switch (loc->getOp())
+ {
+ case kIROp_GetStructuredBufferPtr:
+ case kIROp_ImageSubscript:
+ return isPointerToImmutableLocation(loc->getOperand(0)); // For these ops the answer is whatever operand 0 yields.
+ default:
+ break;
+ }
+
+ auto type = loc->getDataType();
+ if (!type)
+ return false; // No data type to classify.
+
+ switch (type->getOp())
+ {
+ case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLByteAddressBufferType:
+ case kIROp_ConstantBufferType:
+ case kIROp_ParameterBlockType:
+ return true; // Values of these type opcodes are treated as immutable.
+ default:
+ break;
+ }
+
+ if (auto textureType = as<IRTextureType>(type))
+ return textureType->getAccess() == SLANG_RESOURCE_ACCESS_READ; // Textures count only when their access is SLANG_RESOURCE_ACCESS_READ.
+
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ {
+ switch (ptrType->getAddressSpace()) // No default case: any other address space falls through to the final return.
+ {
+ case AddressSpace::BuiltinInput:
+ case AddressSpace::Input:
+ case AddressSpace::MetalObjectData:
+ case AddressSpace::Uniform:
+ case AddressSpace::UniformConstant:
+ return true; // Pointers in these address spaces are treated as immutable.
+ }
+ }
+ return false; // Not recognized as immutable.
+}
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index b7aafea56..c0410fa3c 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -376,6 +376,10 @@ void verifyComputeDerivativeGroupModifiers(
int getIRVectorElementSize(IRType* type);
IRType* getIRVectorBaseType(IRType* type);
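+// Retrieves the element type of an array, vector, matrix, cooperative vector/matrix or HLSL input patch type (nullptr for other types).
+// This is the result type of an ElementExtract operation on a value of `type`; a basic type yields itself and a matrix yields a vector type.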
+IRType* getElementType(IRBuilder& builder, IRType* type);
+
Int getSpecializationConstantId(IRGlobalParam* param);
void legalizeDefUse(IRGlobalValueWithCode* func);
@@ -418,6 +422,22 @@ IRType* getUnsignedTypeFromSignedType(IRBuilder* builder, IRType* type);
bool isSignedType(IRType* type);
bool isIROpaqueType(IRType* type);
+
+// Returns true if the memory location pointed to by `ptrInst` is immutable.
+// An immutable location is a memory region that can't be modified by user code.
+// Examples are ConstantBuffer and shader resource contents (e.g. StructuredBuffer).
+// Note that this is to be distinguished from the access qualifier of the pointer itself,
+// e.g. a `ptrInst` of type `Ptr<T, Access.Read>` may still point to a mutable location,
+// so this function returns false in that case.
+bool isPointerToImmutableLocation(IRInst* ptrInst);
+
+// Check if `use` is the `baseAddr` operand of a GetElement/FieldExtract inst.
+// This is true if `use` is the first operand of the user inst.
+inline bool isUseBaseAddrOperand(IRUse* use, IRInst* user) // `use` is the use being tested; `user` is the inst whose operand list is inspected.
+{
+ return user->getOperandUse(0) == use; // Compares use slots by address; `user` needs at least one operand because getOperandUse asserts the index is in range.
+}
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e9d1a1199..92543c952 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -386,6 +386,20 @@ IRType* tryGetPointedToType(IRBuilder* builder, IRType* type)
return nullptr;
}
+IRType* tryGetPointedToOrBufferElementType(IRBuilder* builder, IRType* type) // Like tryGetPointedToType, but also accepts structured buffer types; nullptr if neither applies.
+{
+ if (auto rateQualType = as<IRRateQualifiedType>(type))
+ {
+ type = rateQualType->getValueType(); // Look through the rate qualifier to the underlying value type.
+ }
+ auto resultType = tryGetPointedToType(builder, type);
+ if (resultType)
+ return resultType; // The pointed-to type takes precedence when available.
+ if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type))
+ return structuredBufferType->getElementType(); // Otherwise fall back to the structured buffer's element type.
+ return nullptr;
+}
+
// IRBlock
@@ -5480,39 +5494,19 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
}
IRType* type = nullptr;
valueType = unwrapAttributedType(valueType);
- if (auto arrayType = as<IRArrayTypeBase>(valueType))
- {
- type = arrayType->getElementType();
- }
- else if (auto vectorType = as<IRVectorType>(valueType))
- {
- type = vectorType->getElementType();
- }
- else if (auto coopVecType = as<IRCoopVectorType>(valueType))
- {
- type = coopVecType->getElementType();
- }
- else if (auto matrixType = as<IRMatrixType>(valueType))
- {
- type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
- }
- else if (auto coopMatType = as<IRCoopMatrixType>(valueType))
- {
- type = coopMatType->getElementType();
- }
- else if (const auto basicType = as<IRBasicType>(valueType))
+ if (as<IRBasicType>(valueType))
{
// HLSL support things like float.x, in which case we just return the base pointer.
return basePtr;
}
- else if (const auto tupleType = as<IRTupleType>(valueType))
- {
- SLANG_ASSERT(as<IRIntLit>(index));
- type = (IRType*)tupleType->getOperand(getIntVal(index));
- }
- else if (auto hlslInputPatchType = as<IRHLSLInputPatchType>(valueType))
+ type = getElementType(*this, (IRType*)valueType);
+ if (!type)
{
- type = hlslInputPatchType->getElementType();
+ if (const auto tupleType = as<IRTupleType>(valueType))
+ {
+ SLANG_ASSERT(as<IRIntLit>(index));
+ type = (IRType*)tupleType->getOperand(getIntVal(index));
+ }
}
SLANG_RELEASE_ASSERT(type);
@@ -6179,6 +6173,24 @@ IRInst* IRBuilder::emitCastIntToPtr(IRType* ptrType, IRInst* val)
return inst;
}
+IRInst* IRBuilder::emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType) // Emits a kIROp_CastStorageToLogical inst with result type `type`, or returns `val` as-is when no cast is needed.
+{
+ if (type == val->getDataType())
+ return val; // Already of the requested type: no inst is emitted.
+ IRInst* args[] = {val, bufferType}; // Operand order: value first, then buffer type.
+ return (IRCastStorageToLogical*)emitIntrinsicInst(type, kIROp_CastStorageToLogical, 2, args);
+}
+
+IRCastStorageToLogicalDeref* IRBuilder::emitCastStorageToLogicalDeref( // Emits a kIROp_CastStorageToLogicalDeref inst with result type `type`.
+ IRType* type,
+ IRInst* val,
+ IRInst* bufferType)
+{
+ IRInst* args[] = {val, bufferType}; // Operand order: value first, then buffer type. Unlike emitCastStorageToLogical, there is no same-type shortcut here.
+ return (IRCastStorageToLogicalDeref*)
+ emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args);
+}
+
IRGlobalConstant* IRBuilder::emitGlobalConstant(IRType* type)
{
auto inst = createInst<IRGlobalConstant>(this, kIROp_GlobalConstant, type);
@@ -6627,6 +6639,7 @@ IRInst* IRBuilder::emitGenericAsm(UnownedStringSlice asmText)
IRInst* IRBuilder::emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index)
{
const auto sbt = cast<IRHLSLStructuredBufferTypeBase>(structuredBuffer->getDataType());
+ SLANG_ASSERT(sbt);
const auto t = getPtrType(sbt->getElementType());
IRInst* const operands[2] = {structuredBuffer, index};
const auto i = createInst<IRRWStructuredBufferGetElementPtr>(
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 69a000b81..161e70b25 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -695,6 +695,12 @@ struct IRInst
IRUse* getOperands();
+ IRUse* getOperandUse(UInt index) // Returns the IRUse slot at `index` (the use itself, not the operand value that getOperand returns).
+ {
+ SLANG_ASSERT(index < getOperandCount()); // Assert that `index` is within the operand count.
+ return getOperands() + index; // Offset into the IRUse array returned by getOperands().
+ }
+
IRInst* getOperand(UInt index)
{
SLANG_ASSERT(index < getOperandCount());
@@ -1533,7 +1539,7 @@ struct IRUniformParameterGroupType : IRParameterGroupType
FIDDLE()
-struct IRGLSLShaderStorageBufferType : IRBuiltinGenericType
+struct IRGLSLShaderStorageBufferType : IRPointerLikeType
{
FIDDLE(leafInst())
IRType* getDataLayout() { return (IRType*)getOperand(1); }
@@ -1760,6 +1766,8 @@ struct IRGetStringHash : IRInst
/// The given IR `builder` will be used if new instructions need to be created.
IRType* tryGetPointedToType(IRBuilder* builder, IRType* type);
+IRType* tryGetPointedToOrBufferElementType(IRBuilder* builder, IRType* type);
+
FIDDLE()
struct IRFuncType : IRType
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ca32b0f2d..5d83a351f 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2062,6 +2062,21 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
}
}
+ IRInst* convertToUInt64Value(IRInst* inst) // Produces a value of the builder's UInt64 type from `inst`.
+ {
+ if (inst->getOp() == kIROp_IntLit) // Integer literal: rebuild it as a UInt64 literal instead of emitting a cast.
+ {
+ auto constVal = as<IRConstant>(inst);
+ return context->irBuilder->getIntValue(
+ context->irBuilder->getUInt64Type(),
+ constVal->value.intVal);
+ }
+ else
+ {
+ return context->irBuilder->emitCast(context->irBuilder->getUInt64Type(), inst); // Non-literal: emit a cast to UInt64.
+ }
+ }
+
IRType* visitPtrType(PtrType* type)
{
auto astValueType = type->getValueType();
@@ -2072,12 +2087,14 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
if (auto astAccessQualifier = type->getAccessQualifier())
{
- accessQualifier = getSimpleVal(context, lowerVal(context, astAccessQualifier));
+ accessQualifier =
+ convertToUInt64Value(getSimpleVal(context, lowerVal(context, astAccessQualifier)));
}
if (auto astAddrSpace = type->getAddressSpace())
{
- addrSpace = getSimpleVal(context, lowerVal(context, astAddrSpace));
+ addrSpace =
+ convertToUInt64Value(getSimpleVal(context, lowerVal(context, astAddrSpace)));
}
else
{
diff --git a/tests/compute/byte-address-buffer-array.slang b/tests/compute/byte-address-buffer-array.slang
index 90cdb2261..58862d1ac 100644
--- a/tests/compute/byte-address-buffer-array.slang
+++ b/tests/compute/byte-address-buffer-array.slang
@@ -1,7 +1,7 @@
// byte-address-buffer-array.slang
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -d3d12 -profile cs_6_0 -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -cuda -profile cs_6_0 -shaderobj -output-using-type
-//DISABLED_TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -output-using-type
//TEST:SIMPLE(filecheck=CHECK2):-target hlsl -entry computeMain -stage compute
//TEST:SIMPLE(filecheck=CHECK3):-target spirv -entry computeMain -stage compute
@@ -10,7 +10,7 @@
// Confirm compilation of `(RW)ByteAddressBuffer` with aligned load / stores to wider data types.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=buffer
-[vk::binding(2, 3)] RWByteAddressBuffer buffer;
+RWByteAddressBuffer buffer;
struct Block {
float4 val[2];
};
diff --git a/tests/optimization/arrray-storage-lowering.slang b/tests/optimization/arrray-storage-lowering.slang
new file mode 100644
index 000000000..42bb8f127
--- /dev/null
+++ b/tests/optimization/arrray-storage-lowering.slang
@@ -0,0 +1,42 @@
+// TEST:SIMPLE(filecheck=SPV): -target spirv
+
+// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type -emit-spirv-directly
+
+struct DoubleNested
+{
+ int4x3 matrix;
+ int getMatVal(int i, int j) { return matrix[i][j]; }
+}
+
+struct Nested
+{
+ bool values[4];
+ DoubleNested doubleNested;
+ int getVal(int id) { return (int)values[0] + doubleNested.getMatVal(0, 1); }
+}
+
+struct Params
+{
+ Nested nested;
+
+ int getVal(int id) { return nested.getVal(id) + nested.getVal(id + 1); }
+}
+
+// TEST_INPUT: set outputBuffer = out ubuffer(data=[0], stride=4)
+RWStructuredBuffer<int4> outputBuffer;
+
+// TEST_INPUT:set gParams = cbuffer(data=[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1])
+ConstantBuffer<Params> gParams;
+
+// TEST_INPUT: set gDoubleNested = ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12])
+uniform DoubleNested *gDoubleNested;
+
+// CHECK: 9
+
+[numthreads(1,1,1)]
+void computeMain(int id: SV_DispatchThreadID)
+{
+ outputBuffer[0].xyz = gParams.getVal(id) + gDoubleNested.getMatVal(1, 1);
+}
+
+// SPV-NOT: OpCompositeConstruct
diff --git a/tests/optimization/get-array-element.slang b/tests/optimization/get-array-element.slang
new file mode 100644
index 000000000..16a71aee2
--- /dev/null
+++ b/tests/optimization/get-array-element.slang
@@ -0,0 +1,17 @@
+//TEST:SIMPLE(filecheck=CHECK):-target spirv
+
+int test(int arr[32]) {
+ int sum = 0;
+ for (int i =0; i < 32; i++) sum += arr[i];
+ return sum;
+}
+
+uniform int gArr[32];
+uniform int* result;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK-NOT: OpCompositeConstruct
+ *result = test(gArr);
+} \ No newline at end of file
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang b/tests/pipeline/rasterization/get-attribute-at-vertex.slang
index 342796b90..f73ebe86f 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang
@@ -9,8 +9,11 @@
// CHECK-SPIRV: OpExtension "SPV_KHR_fragment_shader_barycentric"
// CHECK-SPIRV: OpEntryPoint Fragment %main "main"
-// CHECK-SPIRV: OpDecorate %{{.*}} BuiltIn BaryCoordKHR
-// CHECK-SPIRV: OpDecorate %{{.*}} BuiltIn BaryCoordNoPerspKHR
+
+// CHECK-SPIRV-DAG: OpDecorate %{{.*}} BuiltIn BaryCoordKHR
+
+// CHECK-SPIRV-DAG: OpDecorate %{{.*}} BuiltIn BaryCoordNoPerspKHR
+
// CHECK-SPIRV: %{{.*}} = OpAccessChain %_ptr_Input_{{.*}} %{{.*}} %uint_0
// CHECK-SPIRV: %{{.*}} = OpAccessChain %_ptr_Input_{{.*}} %{{.*}} %uint_1
// CHECK-SPIRV: %{{.*}} = OpAccessChain %_ptr_Input_{{.*}} %{{.*}} %uint_2
diff --git a/tests/spirv/aligned-load-store.slang b/tests/spirv/aligned-load-store.slang
index c2f50b66c..8131e1cc7 100644
--- a/tests/spirv/aligned-load-store.slang
+++ b/tests/spirv/aligned-load-store.slang
@@ -4,8 +4,6 @@
// CHECK: OpStore {{.*}} Aligned 16
// CHECK: OpLoad {{.*}} Aligned 16
-// CHECK: OpLoad {{.*}} Aligned 16
-// CHECK: OpStore {{.*}} Aligned 16
// CHECK: OpStore {{.*}} Aligned 16
uniform float4* data;
diff --git a/tests/spirv/buffer-pointer-matrix-layout.slang b/tests/spirv/buffer-pointer-matrix-layout.slang
index cbb8f2857..4d80419e4 100644
--- a/tests/spirv/buffer-pointer-matrix-layout.slang
+++ b/tests/spirv/buffer-pointer-matrix-layout.slang
@@ -1,33 +1,50 @@
-//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -stage compute -entry main -matrix-layout-column-major
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -output-using-type -xslang -matrix-layout-column-major -emit-spirv-directly
-// CHECK: OpLoad {{.*}} Aligned 4
+// TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0 12.0],stride=4)
+uniform float3x4 *ptr;
-struct Push
-{
- float3x4* ptr;
-};
+// TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0],stride=4)
+RWStructuredBuffer<float> outputBuffer;
-[[vk::push_constant]] Push push;
[shader("compute")]
[numthreads(1, 1, 1)]
-void main(uint3 dtid : SV_DispatchThreadID)
-{
+void computeMain(uint3 dtid: SV_DispatchThreadID)
+{
// This matrix is column major in memory. Slang respects this here and loads it properly!
- float3x4 correctly_read_matrix = *push.ptr;
- printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n",
- correctly_read_matrix[0][0], correctly_read_matrix[0][1], correctly_read_matrix[0][2], correctly_read_matrix[0][3],
- correctly_read_matrix[1][0], correctly_read_matrix[1][1], correctly_read_matrix[1][2], correctly_read_matrix[1][3]
- );
- printf("(%f,%f,%f,%f)\n\n",
- correctly_read_matrix[2][0], correctly_read_matrix[2][1], correctly_read_matrix[2][2], correctly_read_matrix[2][3]
- );
- // With this syntax however, Slang ignores the column major setting and loads it as it it was row major!
- float3x4 broken_matrix = push.ptr[0];
- printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n",
- broken_matrix[0][0], broken_matrix[0][1], broken_matrix[0][2], broken_matrix[0][3],
- broken_matrix[1][0], broken_matrix[1][1], broken_matrix[1][2], broken_matrix[1][3]
- );
- printf("(%f,%f,%f,%f)\n\n",
- broken_matrix[2][0], broken_matrix[2][1], broken_matrix[2][2], broken_matrix[2][3]
- );
+ float3x4 correctly_read_matrix = *ptr;
+ outputBuffer[0] = correctly_read_matrix[0][0];
+ outputBuffer[1] = correctly_read_matrix[0][1];
+ outputBuffer[2] = correctly_read_matrix[0][2];
+ outputBuffer[3] = correctly_read_matrix[0][3];
+ outputBuffer[4] = correctly_read_matrix[1][0];
+ outputBuffer[5] = correctly_read_matrix[1][1];
+ outputBuffer[6] = correctly_read_matrix[1][2];
+ outputBuffer[7] = correctly_read_matrix[1][3];
+ // CHECK: 1.0
+ // CHECK: 4.0
+ // CHECK: 7.0
+ // CHECK: 10.0
+ // CHECK: 2.0
+ // CHECK: 5.0
+ // CHECK: 8.0
+ // CHECK: 11.0
+
+ // With this syntax, however, Slang was ignoring the column major setting and loaded it as if it were row major!
+ float3x4 broken_matrix = ptr[0];
+ outputBuffer[8] = broken_matrix[0][0];
+ outputBuffer[9] = broken_matrix[0][1];
+ outputBuffer[10] = broken_matrix[0][2];
+ outputBuffer[11] = broken_matrix[0][3];
+ outputBuffer[12] = broken_matrix[1][0];
+ outputBuffer[13] = broken_matrix[1][1];
+ outputBuffer[14] = broken_matrix[1][2];
+ outputBuffer[15] = broken_matrix[1][3];
+ // CHECK: 1.0
+ // CHECK: 4.0
+ // CHECK: 7.0
+ // CHECK: 10.0
+ // CHECK: 2.0
+ // CHECK: 5.0
+ // CHECK: 8.0
+ // CHECK: 11.0
} \ No newline at end of file
diff --git a/tests/spirv/geometry-shader-sub-func.slang b/tests/spirv/geometry-shader-sub-func.slang
index 6c6944f31..20634ea67 100644
--- a/tests/spirv/geometry-shader-sub-func.slang
+++ b/tests/spirv/geometry-shader-sub-func.slang
@@ -36,7 +36,7 @@ void main(
{
CoarseVertex coarseVertex = coarseVertices[ii];
RasterVertex rasterVertex;
- rasterVertex.position = coarseVertex.position;
+ rasterVertex.position = coarseVertex.position;
rasterVertex.color = coarseVertex.color;
rasterVertex.id = coarseVertex.id + primitiveID;
appendVertex(outputStream, rasterVertex);
diff --git a/tests/spirv/large-struct.slang b/tests/spirv/large-struct.slang
index 2d79c0aaf..e4cbd6d1c 100644
--- a/tests/spirv/large-struct.slang
+++ b/tests/spirv/large-struct.slang
@@ -1,5 +1,5 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type -xslang -g0
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-d3d12 -compute -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -output-using-type
diff --git a/tests/spirv/pointer-2.slang b/tests/spirv/pointer-2.slang
index 1f2b2d0ea..b93ca32b4 100644
--- a/tests/spirv/pointer-2.slang
+++ b/tests/spirv/pointer-2.slang
@@ -1,6 +1,6 @@
-//TEST:SIMPLE(filecheck=CHECK_SPV): -entry vertexMain -stage vertex -emit-spirv-directly -target spirv
-//TEST:SIMPLE(filecheck=CHECK_SPV_VIA_GLSL): -entry vertexMain -stage vertex -emit-spirv-via-glsl -target spirv
-//TEST:SIMPLE(filecheck=CHECK_GLSL): -entry vertexMain -stage vertex -target glsl
+// TEST:SIMPLE(filecheck=CHECK_GLSL): -entry vertexMain -stage vertex -target glsl
+// TEST:SIMPLE(filecheck=CHECK_SPV): -entry vertexMain -stage vertex -emit-spirv-directly -target spirv
+// TEST:SIMPLE(filecheck=CHECK_SPV_VIA_GLSL): -entry vertexMain -stage vertex -emit-spirv-via-glsl -target spirv
struct Inner1
{
diff --git a/tests/spirv/spec-constant-operations.slang b/tests/spirv/spec-constant-operations.slang
index 86d16ef34..7c7b9bc60 100644
--- a/tests/spirv/spec-constant-operations.slang
+++ b/tests/spirv/spec-constant-operations.slang
@@ -11,8 +11,6 @@ RWStructuredBuffer<float> outputBuffer;
// CHECK-DAG: OpSpecConstant %float 1
// CHECK-DAG: OpSpecConstant %ulong 256
// CHECK-DAG: OpSpecConstant %float 100
-// CHECK-DAG: OpSpecConstantOp %half FConvert
-// CHECK-DAG: OpSpecConstantOp %int UConvert
// CHECK-NOT: OpSpecConstantOp {{.*}} FAdd
// CHECK-NOT: OpSpecConstantOp {{.*}} FSub
diff --git a/tests/spirv/spirv-debug-break.slang b/tests/spirv/spirv-debug-break.slang
index e57024037..67b3e975b 100644
--- a/tests/spirv/spirv-debug-break.slang
+++ b/tests/spirv/spirv-debug-break.slang
@@ -1,5 +1,5 @@
// spirv-instruction.slang
-//TEST(compute, vulkan):SIMPLE(filecheck=CHECK):-target glsl -entry computeMain -stage compute
+// TEST:SIMPLE(filecheck=CHECK):-target glsl -entry computeMain -stage compute
[[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]]
void _spvDebugBreak(int v);
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index bb213cdaa..c03afaa9e 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -1015,9 +1015,12 @@ struct ShaderInputLayoutParser
for (auto& line : lines)
{
lineNum++;
- if (line.startsWith("//TEST_INPUT:"))
+ if (!line.startsWith("//"))
+ continue;
+ line = line.getUnownedSlice().tail(2).trim();
+ if (line.startsWith("TEST_INPUT:"))
{
- auto lineContent = line.subString(13, line.getLength() - 13);
+ auto lineContent = line.subString(11, line.getLength() - 11);
Misc::TokenReader parser(lineContent);
try
{