From a6deb5ed82cb8fc6b4f4c5c5fee264e09f97ff89 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 29 Sep 2025 17:45:08 -0700 Subject: Rewriting the lower-buffer-element-type pass to avoid unnecessary packing/unpacking. (#8526) Part of the effort to improve the performance of generated SPIRV code. The existing lower-buffer-element-type pass works by loading the entire buffer element content from memory, and translate it to logical type stored in a local variable at the earliest reference of a buffer handle. This means that is can generate inefficient code that reads more than necessary. Consider this example: ``` struct BigStruct { bool values[1024]; } ConstantBuffer cb; void test(BigStruct v) { if (v.values[0]) { printf("ok"); } } [numthreads(1,1,1)] void computeMain() { test(cb); } ``` In IR, the `computeMain` function before lower-buffer-element-type pass is something like following: ``` func test: %v = param : BigStruct %barr = fieldExtract(%v, "values") %element = elementExtract(%barr, 0) ... // uses %element func computeMain: %v = load(cb) call %test %v ``` The existing lower-buffer-element-type pass will rewrite the bool array in `BigStruct` into `int` array so it is legal in SPIRV. However, it does so by inserting the translation on the first `load` of the constant buffer: ``` struct BigStruct_std430 { int values[1024]; } var cb : ConstantBuffer; func computeMain: %tmpVar : var call %unpackStorage(%tmpVar, cb) %v : BigStruct = load %tmpVar call %test %v ``` This means that the entire array will be loaded and translated to int, before calling `test`, which only uses one element. It turns out that the downstream compiler isn't always able to optimize out this inefficient translation/copy. This PR completely rewrites the way buffer-element-type lowering is handled to avoid producing this inefficient code. It works in two parts: first we turn on the `transformParamsToConstRef` pass for SPIRV target as well, so we will translate the `test` function to take the `v` parameter as `constref`. The second part is a redesigned buffer-element-type pass that defers the storage-type to logical-type translation until a value is actually used by a `load` instruction. In this example, after `transformParamsToConstRef`, the IR is: ``` func test: %v = param : ConstRef %barr = fieldAddr(%v, "values") %elementPtr = elementAddr(%barr, 0) %element = load(%elementPtr) ... // uses %element func computeMain: call %test %cb ``` The new `buffer-element-type-lowering` pass will take this IR, and insert translation at latest possible time across the entire call graph, and translate the IR into: ``` func test: %v = param : ConstRef %barr = fieldAddr(%v, "values") %elementPtr : ptr = elementAddr(%barr, 0) %element_int = load(%elementPtr) %element = cast(%element_int) : %bool ... // uses %element func computeMain: call %test %cb ``` In this new IR, there is no longer a load and conversion of the entire array. See new comment in `slang-ir-lower-buffer-element-type.cpp` for more details of how the pass works. This PR also address many other issues surfaced by turning on `transformParamsToConstRef` pass on SPIRV backend. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang-glslang/slang-glslang.cpp | 7 +- source/slang/core.meta.slang | 1 - source/slang/slang-emit-spirv.cpp | 2 +- source/slang/slang-emit.cpp | 37 +- source/slang/slang-ir-defer-buffer-load.cpp | 38 +- source/slang/slang-ir-glsl-legalize.cpp | 110 +- source/slang/slang-ir-insts-stable-names.lua | 6 +- source/slang/slang-ir-insts.h | 14 + source/slang/slang-ir-insts.lua | 12 + source/slang/slang-ir-legalize-types.cpp | 10 +- .../slang/slang-ir-lower-buffer-element-type.cpp | 1507 ++++++++++++++------ source/slang/slang-ir-lower-buffer-element-type.h | 2 +- source/slang/slang-ir-metal-legalize.cpp | 17 +- source/slang/slang-ir-metal-legalize.h | 2 + source/slang/slang-ir-redundancy-removal.cpp | 3 + source/slang/slang-ir-specialize-address-space.cpp | 39 +- source/slang/slang-ir-specialize-address-space.h | 7 + source/slang/slang-ir-spirv-legalize.cpp | 46 +- .../slang-ir-transform-params-to-constref.cpp | 142 +- source/slang/slang-ir-undo-param-copy.cpp | 14 +- source/slang/slang-ir-util.cpp | 77 + source/slang/slang-ir-util.h | 20 + source/slang/slang-ir.cpp | 69 +- source/slang/slang-ir.h | 10 +- source/slang/slang-lower-to-ir.cpp | 21 +- tests/compute/byte-address-buffer-array.slang | 4 +- tests/optimization/arrray-storage-lowering.slang | 42 + tests/optimization/get-array-element.slang | 17 + .../rasterization/get-attribute-at-vertex.slang | 7 +- tests/spirv/aligned-load-store.slang | 2 - tests/spirv/buffer-pointer-matrix-layout.slang | 69 +- tests/spirv/geometry-shader-sub-func.slang | 2 +- tests/spirv/large-struct.slang | 2 +- tests/spirv/pointer-2.slang | 6 +- tests/spirv/spec-constant-operations.slang | 2 - tests/spirv/spirv-debug-break.slang | 2 +- tools/render-test/shader-input-layout.cpp | 7 +- 37 files changed, 1754 insertions(+), 621 deletions(-) create mode 100644 tests/optimization/arrray-storage-lowering.slang create mode 100644 tests/optimization/get-array-element.slang diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp index 56be8b042..1c91a97bc 100644 --- a/source/slang-glslang/slang-glslang.cpp +++ b/source/slang-glslang/slang-glslang.cpp @@ -241,7 +241,12 @@ extern "C" #endif bool glslang_disassembleSPIRV(const uint32_t* contents, int contentsSize) { - return glslang_disassembleSPIRVWithResult(contents, contentsSize, nullptr); + char* result = nullptr; + auto succ = glslang_disassembleSPIRVWithResult(contents, contentsSize, &result); + if (result) + fprintf(stdout, "%s\n", result); + delete result; + return succ; } // Apply the SPIRV-Tools optimizer to generated SPIR-V based on the desired optimization level diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 43658563e..54e693117 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1312,7 +1312,6 @@ enum Access : uint64_t /// @param T The type of the value pointed to. /// @remarks `T* val` is equivalent to `Ptr val`. __magic_type(PtrType) -__intrinsic_type($(kIROp_PtrType)) struct Ptr< T, Access access = Access::ReadWrite, diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 7b1bd66d8..ea49cb08c 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -7038,7 +7038,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex getStructFieldId(baseStructType, as(fieldAddress->getField())), builder.getIntType()); SLANG_ASSERT(as(fieldAddress->getFullType())); - return emitOpInBoundsAccessChain( + return emitOpAccessChain( parent, fieldAddress, fieldAddress->getFullType(), diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index e1689ccfc..f1cc6090d 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -97,6 +97,7 @@ #include "slang-ir-restructure.h" #include "slang-ir-sccp.h" #include "slang-ir-simplify-for-emit.h" +#include "slang-ir-specialize-address-space.h" #include "slang-ir-specialize-arrays.h" #include "slang-ir-specialize-buffer-load-arg.h" #include "slang-ir-specialize-matrix-layout.h" @@ -1715,6 +1716,7 @@ Result linkAndOptimizeIR( if (targetProgram->getOptionSet().getBoolOption( CompilerOptionName::EnableExperimentalPasses)) introduceExplicitGlobalContext(irModule, target); + transformParamsToConstRef(irModule, codeGenContext->getSink()); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "EXPLICIT GLOBAL CONTEXT INTRODUCED"); #endif @@ -1812,11 +1814,11 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.meshOutput) legalizeMeshOutputTypes(irModule); - BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions; - bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer = - isWGPUTarget(targetRequest); - lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions); - performForceInlining(irModule); + + // Lower all bit_cast operations on complex types into leaf-level + // bit_cast on basic types. + if (requiredLoweringPassSet.bitcast) + lowerBitCast(targetProgram, irModule, sink); // Rewrite functions that return arrays to return them via `out` parameter, // since our target languages doesn't allow returning arrays. @@ -1832,13 +1834,28 @@ Result linkAndOptimizeIR( rcpWOfPositionInput(irModule); } - // Lower all bit_cast operations on complex types into leaf-level - // bit_cast on basic types. - if (requiredLoweringPassSet.bitcast) - lowerBitCast(targetProgram, irModule, sink); - bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly(); + BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions; + bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer = + isWGPUTarget(targetRequest); + lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions); + + // If we are generating code for glsl or metal, perform address space propagation now. + // For SPIRV, we will do that during spirv legalization that happens after + // `linkAndOptimizeIR`. + if (target == CodeGenTarget::GLSL) + { + NoOpInitialAddressSpaceAssigner addrSpaceAssigner; + specializeAddressSpace(irModule, &addrSpaceAssigner); + } + else if (isMetalTarget(targetRequest)) + { + specializeAddressSpaceForMetal(irModule); + } + + performForceInlining(irModule); + if (emitSpirvDirectly) { performIntrinsicFunctionInlining(irModule); diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp index e71892fe5..51c6a161b 100644 --- a/source/slang/slang-ir-defer-buffer-load.cpp +++ b/source/slang/slang-ir-defer-buffer-load.cpp @@ -67,38 +67,6 @@ struct DeferBufferLoadContext return result; } - static bool isImmutableLocation(IRInst* loc) - { - switch (loc->getOp()) - { - case kIROp_GetStructuredBufferPtr: - case kIROp_ImageSubscript: - return isImmutableLocation(loc->getOperand(0)); - default: - break; - } - - auto type = loc->getDataType(); - if (!type) - return false; - - switch (type->getOp()) - { - case kIROp_HLSLStructuredBufferType: - case kIROp_HLSLByteAddressBufferType: - case kIROp_ConstantBufferType: - case kIROp_ParameterBlockType: - return true; - default: - break; - } - - if (auto textureType = as(type)) - return textureType->getAccess() == SLANG_RESOURCE_ACCESS_READ; - - return false; - } - static bool isImmutableBufferLoad(IRInst* inst) { // Note: we cannot defer loads from RWStructuredBuffer because there can be other @@ -111,7 +79,7 @@ struct DeferBufferLoadContext case kIROp_Load: { auto rootAddr = getRootAddr(inst->getOperand(0)); - return isImmutableLocation(rootAddr); + return isPointerToImmutableLocation(rootAddr); } default: return false; @@ -132,14 +100,14 @@ struct DeferBufferLoadContext align = load->findAttr(); if (!as(ptr->getParent())) { - builder.setInsertAfter(ptr); + setInsertAfterOrdinaryInst(&builder, ptr); IRType* valueType = tryGetPointedToType(&builder, ptr->getFullType()); result = builder.emitLoad(valueType, ptr, align); mapPtrToValue[ptr] = result; } else { - builder.setInsertBefore(loadInst); + setInsertBeforeOrdinaryInst(&builder, loadInst); IRType* valueType = tryGetPointedToType(&builder, ptr->getFullType()); result = builder.emitLoad(valueType, ptr, align); // Since we are inserting the load in a local scope, we can't register diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index a2f56cf7d..a79ca2379 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1079,6 +1079,12 @@ IRInst* getOrCreateBuiltinParamForHullShader( if (sysAttr->getName().caseInsensitiveEquals(builtinSemantic)) { outputControlPointIdParam = param; + if (as(outputControlPointIdParam->getDataType())) + { + IRBuilder builder(param); + setInsertAfterOrdinaryInst(&builder, param); + outputControlPointIdParam = builder.emitLoad(param); + } break; } } @@ -2348,11 +2354,11 @@ ScalarizedVal getSubscriptVal( auto inputAdapter = val.impl.as(); RefPtr resultAdapter = new ScalarizedTypeAdapterValImpl(); - resultAdapter->pretendType = inputAdapter->pretendType; - resultAdapter->actualType = inputAdapter->actualType; + resultAdapter->pretendType = elementType; + resultAdapter->actualType = getElementType(*builder, inputAdapter->actualType); resultAdapter->val = - getSubscriptVal(builder, inputAdapter->actualType, inputAdapter->val, indexVal); + getSubscriptVal(builder, resultAdapter->actualType, inputAdapter->val, indexVal); return ScalarizedVal::typeAdapter(resultAdapter); } @@ -3127,7 +3133,7 @@ void tryReplaceUsesOfStageInput( { auto user = use->getUser(); IRBuilder builder(user); - builder.setInsertBefore(user); + setInsertBeforeOrdinaryInst(&builder, user); builder.replaceOperand(use, val.irValue); }); } @@ -3155,7 +3161,7 @@ void tryReplaceUsesOfStageInput( return; } IRBuilder builder(user); - builder.setInsertBefore(user); + setInsertBeforeOrdinaryInst(&builder, user); if (needMaterialize) { auto materializedVal = materializeValue(&builder, val); @@ -3176,22 +3182,50 @@ void tryReplaceUsesOfStageInput( { auto user = use->getUser(); IRBuilder builder(user); - builder.setInsertBefore(user); + setInsertBeforeOrdinaryInst(&builder, user); auto typeAdapter = as(val.impl); - auto materializedInner = materializeValue(&builder, typeAdapter->val); - auto adapted = adaptType( - &builder, - materializedInner, - typeAdapter->pretendType, - typeAdapter->actualType); - if (user->getOp() == kIROp_Load) - { - user->replaceUsesWith(adapted.irValue); - user->removeAndDeallocate(); - } - else + switch (user->getOp()) { - use->set(adapted.irValue); + case kIROp_Load: + { + auto materialized = materializeValue(&builder, val); + user->replaceUsesWith(materialized); + user->removeAndDeallocate(); + } + break; + case kIROp_GetElementPtr: + { + auto targetType = typeAdapter->pretendType; + auto elementType = getElementType(builder, targetType); + SLANG_ASSERT(elementType); + auto subscriptVal = getSubscriptVal( + &builder, + (IRType*)elementType, + val, + user->getOperand(1)); + tryReplaceUsesOfStageInput(context, subscriptVal, user); + } + break; + case kIROp_FieldAddress: + { + auto targetType = as(typeAdapter->pretendType); + SLANG_ASSERT(targetType); + auto subscriptVal = extractField( + &builder, + val, + kMaxUInt, + (IRStructKey*)user->getOperand(1)); + tryReplaceUsesOfStageInput(context, subscriptVal, user); + } + break; + default: + { + auto materialized = materializeValue(&builder, val); + auto tmpVar = builder.emitVar(materialized->getDataType()); + builder.emitStore(tmpVar, materialized); + use->set(tmpVar); + } + break; } }); } @@ -3205,13 +3239,13 @@ void tryReplaceUsesOfStageInput( auto arrayIndexImpl = as(val.impl); auto user = use->getUser(); IRBuilder builder(user); - builder.setInsertBefore(user); + setInsertBeforeOrdinaryInst(&builder, user); auto subscriptVal = getSubscriptVal( &builder, arrayIndexImpl->elementType, arrayIndexImpl->arrayVal, arrayIndexImpl->index); - builder.setInsertBefore(user); + setInsertBeforeOrdinaryInst(&builder, user); auto materializedInner = materializeValue(&builder, subscriptVal); if (user->getOp() == kIROp_Load) { @@ -3220,7 +3254,9 @@ void tryReplaceUsesOfStageInput( } else { - use->set(materializedInner); + auto tmpVar = builder.emitVar(materializedInner->getDataType()); + builder.emitStore(tmpVar, materializedInner); + use->set(tmpVar); } }); break; @@ -3233,6 +3269,9 @@ void tryReplaceUsesOfStageInput( [&](IRUse* use) { auto user = use->getUser(); + IRBuilder builder(user); + setInsertBeforeOrdinaryInst(&builder, user); + switch (user->getOp()) { case kIROp_FieldExtract: @@ -3270,10 +3309,20 @@ void tryReplaceUsesOfStageInput( } } break; + case kIROp_GetElementPtr: + { + auto arrayType = as(tupleVal->type); + SLANG_ASSERT(arrayType); + auto subscriptVal = getSubscriptVal( + &builder, + (IRType*)arrayType->getElementType(), + val, + user->getOperand(1)); + tryReplaceUsesOfStageInput(context, subscriptVal, user); + } + break; case kIROp_Load: { - IRBuilder builder(user); - builder.setInsertBefore(user); auto materializedVal = materializeTupleValue(&builder, val); user->replaceUsesWith(materializedVal); user->removeAndDeallocate(); @@ -3449,7 +3498,7 @@ void legalizeEntryPointParameterForGLSL( // Okay, we have a declaration, and we want to modify it! - builder->setInsertBefore(ii); + setInsertBeforeOrdinaryInst(builder, ii); assign(builder, globalOutputVal, ScalarizedVal::value(ii->getOperand(2))); } @@ -3768,12 +3817,13 @@ void legalizeEntryPointParameterForGLSL( blockToMaterialized.tryGetValue(callingBlock, materialized); if (!found) { - replaceBuilder.setInsertBefore(callingBlock->getFirstInst()); + replaceBuilder.setInsertBefore( + callingBlock->getFirstOrdinaryInst()); materialized = materializeValue(&replaceBuilder, globalValue); blockToMaterialized.set(callingBlock, materialized); } - replaceBuilder.setInsertBefore(user); + setInsertBeforeOrdinaryInst(builder, user); auto field = replaceBuilder.emitFieldExtract(globalVarType, materialized, key); replaceBuilder.replaceOperand(operandUse, field); @@ -3888,7 +3938,7 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module) { rayPayloadCounter++; } - builder.setInsertBefore(inst); + setInsertBeforeOrdinaryInst(&builder, inst); location = builder.getIntValue(builder.getIntType(), rayPayloadCounter); decor->setOperand(0, location); rayPayloadCounter++; @@ -3902,7 +3952,7 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module) { callablePayloadCounter++; } - builder.setInsertBefore(inst); + setInsertBeforeOrdinaryInst(&builder, inst); location = builder.getIntValue(builder.getIntType(), callablePayloadCounter); decor->setOperand(0, location); callablePayloadCounter++; @@ -3915,7 +3965,7 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module) { hitObjectAttributeCounter++; } - builder.setInsertBefore(inst); + setInsertBeforeOrdinaryInst(&builder, inst); location = builder.getIntValue(builder.getIntType(), hitObjectAttributeCounter); decor->setOperand(0, location); hitObjectAttributeCounter++; diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index b2c216bb4..25d54eb04 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -670,6 +670,10 @@ return { ["SPIRVAsmOperand.__imageType"] = 666, ["SPIRVAsmOperand.__sampledImageType"] = 667, ["Type.CLayout"] = 668, - ["CastUInt64ToDescriptorHandle"] = 669, + ["CastUInt64ToDescriptorHandle"] = 669, ["CastDescriptorHandleToUInt64"] = 670, + ["CastStorageToLogicalBase.CastStorageToLogical"] = 671, + ["CastStorageToLogicalBase.CastStorageToLogicalDeref"] = 672, + ["Decoration.DisableCopyEliminationDecoration"] = 673, + ["Decoration.TempCallArgImmutableVar"] = 674, } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e3014119c..c64f65ccb 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3255,6 +3255,14 @@ struct IRCastFloatToInt : IRInst FIDDLE(leafInst()) }; +FIDDLE() +struct IRCastStorageToLogicalBase : IRInst +{ + FIDDLE(baseInst()) + IRInst* getVal() { return getOperand(0); } + IRInst* getBufferType() { return getOperand(1); } +}; + FIDDLE() struct IRDebugSource : IRInst { @@ -4573,6 +4581,12 @@ public: IRInst* emitCastPtrToInt(IRInst* val); IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val); + IRInst* emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType); + IRCastStorageToLogicalDeref* emitCastStorageToLogicalDeref( + IRType* type, + IRInst* val, + IRInst* bufferType); + IRGlobalConstant* emitGlobalConstant(IRType* type); IRGlobalConstant* emitGlobalConstant(IRType* type, IRInst* val); diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 0ad02b87c..5f54707a1 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -1473,6 +1473,7 @@ local insts = { struct_name = "RequireFullQuadsDecoration", }, }, + { TempCallArgImmutableVar = { struct_name = "TempCallArgImmutableVarDecoration" } }, { TempCallArgVar = { struct_name = "TempCallArgVarDecoration" } }, { nonCopyable = { @@ -1480,6 +1481,7 @@ local insts = { struct_name = "NonCopyableTypeDecoration", }, }, + { DisableCopyEliminationDecoration = {} }, { DynamicUniform = { -- Marks a value to be dynamically uniform. @@ -1891,6 +1893,16 @@ local insts = { { EnumCast = { min_operands = 1 } }, { CastUInt2ToDescriptorHandle = { min_operands = 1 } }, { CastDescriptorHandleToUInt2 = { min_operands = 1 } }, + -- Represents a psuedo cast to convert between a logical type (user declared) and a storage Type + -- (valid in buffer locations). The operand can either be a value or an address. + { + CastStorageToLogicalBase = + { + min_operands = 2, struct_name = "CastStorageToLogicalBase", + { CastStorageToLogical = { min_operands = 2, struct_name = "CastStorageToLogical" } }, + { CastStorageToLogicalDeref = { min_operands = 2, struct_name = "CastStorageToLogicalDeref" } }, + } + }, { CastUInt64ToDescriptorHandle = { min_operands = 1 } }, { CastDescriptorHandleToUInt64 = { min_operands = 1 } }, -- Represents a no-op cast to convert a resource pointer to a resource on targets where the resource handles are diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 085c3d933..27abdeaf0 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2344,6 +2344,10 @@ static LegalVal legalizeCoopMatMapElementIFunc( static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst) { + LegalVal legalVal; + if (context->mapValToLegalVal.tryGetValue(inst, legalVal)) + return legalVal; + // Any additional instructions we need to emit // in the process of legalizing `inst` should // by default be insertied right before `inst`. @@ -2463,7 +2467,7 @@ static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst) auto builder = context->builder; builder->setInsertBefore(inst); - LegalVal legalVal = legalizeInst(context, inst, legalType, legalArgs.getArrayView().arrayView); + legalVal = legalizeInst(context, inst, legalType, legalArgs.getArrayView().arrayView); if (legalVal.flavor == LegalVal::Flavor::simple) { @@ -2789,7 +2793,9 @@ private: static LegalVal legalizeFunc(IRTypeLegalizationContext* context, IRFunc* irFunc) { LegalFuncBuilder builder(context); - return builder.build(irFunc); + auto legalVal = builder.build(irFunc); + registerLegalizedValue(context, irFunc, legalVal); + return legalVal; } static void cloneDecorationToVar(IRInst* srcInst, IRInst* varInst) diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 056ee6244..128502bd8 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -6,6 +6,211 @@ #include "slang-ir-util.h" #include "slang-ir.h" +/// This file implements an important IR transformation pass in the Slang compiler +/// that rewrites buffer element types into valid storage types, a.k.a physical types +/// in SPIRV terminology. +/// +/// Many of our targets have special restrictions on what is allowed to be used as a +/// buffer element. Examples are: +/// - In HLSL and SPIRV, if you have ConstantBuffer, T must be a struct. +/// - In SPIRV, `bool` is considered a logical type, meaning it cannot appear inside +/// buffers. bool vectors and matrices needs to be lowered into arrays. +/// - In SPIRV, if `T` is used to declare a buffer, then every member in `T` must have +/// explicit offset. But if it is used to declare a local variable, then it cannot +/// have explicit member offset. This means that we cannot use the same `Foo` struct +/// inside a `StructuredBuffer` and also use it to declare a local variable. +/// +/// We use the terms "physical", "storage", or "lowered" types to refer to types that +/// are legal to use as buffer elements. In contrast, the terms "original" or "logical" +/// refers to types that are declared by the user in its original form. +/// For example, `bool4` is a "logical" type, and its lowered type is `int4`. +/// +/// +/// # Algorithm Overview +/// ---------------------- +/// +/// This pass performs the transformation to create one "storage" type for each type that +/// are used in each kind of buffer. For example, if user defined `Foo`, and used it in +/// `ConstantBuffer` and `StructuredBuffer` and is targeting SPIRV, this pass will +/// create `Foo_std140` and `Foo_std430` types, and update the buffer to be +/// `ConstantBuffer` and `StructuredBuffer`. +/// +/// The pass will rewrite all the code that uses this buffers, and insert translations between +/// Foo_std140/Foo_std430 and Foo to keep types consistent. +/// +/// For example, given: +/// ``` +/// struct Foo { +/// bool4x4 v; +/// } +/// ConstantBuffer cb; +/// bool test(Foo f) { +/// return f.v[0][1]; +/// } +/// void main() { test(cb); } +/// ``` +/// +/// This pass will rewrite it as: +/// ``` +/// struct Foo { +/// bool4x4 v; +/// } +/// struct Foo_std140 { +/// Matrix_bool4x4_std140 v; +/// }; +/// struct Matrix_bool4x4_std140 { +/// int4 values[4]; +/// }; +/// ConstantBuffer cb; +/// bool test_1(Foo_std140 f) { +/// return f.v.values[0][1]; +/// } +/// void main() { test_1(cb); } +/// ``` +/// +/// Note that the one important optimization here is we will defer the translation from +/// storage type to logical type at latest possible time. In the example above, we could +/// have loaded `cb` and then immediately translate it into `Foo` and call `test` with +/// the translated value. However that can lead to code that create unnecessary copies +/// that can't always be removed by the downstream compiler, particulary if there are +/// arrays whose element type needs non-trivial translation. +/// +/// To avoid the performance issue, we will defer this translation until a logical value +/// is actually needed. This is done by pushing the translation to the use sites, and +/// across function call boundaries, specializing any functions being called along the +/// chain. This case, since we are calling `test()` from `main()` with `Foo_std140`, instead +/// of converting the `Foo_std140` to `Foo` before the call, we create a specialization +/// of `test` that accepts `Foo_std140` instead. +/// +/// To enable this interprecedural transformation, the pass is organized as two phases: +/// 1. Create lowered / storage types for all buffer element types, and update +/// global buffer declarations to use storage types. This is implemented in `processModule()` +/// 2. Insert a `CastStorageToLogical(loweredBuffer)` inst, and replace all uses of +/// `loweredBuffer` with the cast inst. This is implemented in `processModule()` +/// 3. Push the `CastStorageToLogical` insts to as late as possible, which means if we see +/// `FieldAddress(CastStorageToLogical(storageAddr), memberKey)`, we should translate +/// it into `CastStorageToLogical(FieldAddress(storageAddr, memberKey)`. +/// If we see a `CastStorageToLogical` inst being used as argument to call a function `f`, +/// specialize `f` to take a pointer to the storage type instead, and insert a +/// `CastStorageToLogical(param)` to convert the param type to logical type at the +/// beginning of the specialized function. (implemented in `deferStorageToLogicalCasts()`) +/// +/// Repeat step 2 and 3 until no more changes can be made, then proceed to step 4. +/// +/// 4. Materialize all remaining `CastStorageToLogical(addr)` by replacing all `load` of such +/// cast insts with `call unpackStorage(addr)`, where `unpackStorage` is a function we +/// synthesize that reads from an address of a storage type and returns a logical type; +/// and replacing all `store(CastStorageToLogical(addr), value)` with `packStorage(addr, value)`, +/// where `packStorage` is a function we synthesis that writes a logical value into a storage +/// addr. This is implemented in `materializeStorageToLogicalCasts()`. +/// +/// That's the main idea of the pass. +/// +/// # Propagating through SSA values +/// +/// Note that `kIROp_CastStorageToLogical` is a pseudo instruction introduced in this pass that +/// has the semantics of "converting a pointer to a storage value into a pointer to a logical +/// value". A dual of this inst is `kIROp_CastStorageToLogicalDeref`, which has an additional +/// builtin "load" semantic. That is, given `Ptr addr`, `CastStorageToLogical(addr)` +/// will have type `Ptr`, and `CastStorageToLogicalDeref(addr)` will have type +/// `LogicalType`. In other words, `CastStorageToLogicalDeref(addr)` is equivalent to +/// `load(CastStorageToLogical(addr))`. +/// +/// The `CastStorageToLogicalDeref` pseudo inst is needed to push defer through `load`s. +/// Consider the following example: +/// ``` +/// ptr : StorageType* = ... +/// lptr : LogicalType* = CastStorageToLogical(ptr); +/// l = load(lptr) +/// m = fieldExtract(l, member) +/// call f, m +/// ``` +/// In this case, only l.member is used, so we should avoid translating other unrelated members +/// from storage type to logical type. To achieve this we must be able to push the +/// `CastStorageToLogical` operation beyond the `load`. The steps to achieve this are: +/// 1. we process `lptr` inst by inspecting its users. We find that a `load` (l) uses it. +/// 2. replace the `load` with `CastStorageToLogicalDeref(ptr)`, the IR become: +/// ``` +/// ptr : StorageType* = ... +/// l_1 = CastStorageToLogicalDeref(ptr); +/// m = fieldExtract(l_1, member); +/// call f, m +/// ``` +/// 3. push the new `l_1` inst to worklist, and when it gets processed, we continue to inspect +/// its users, and find that it is being used by `fieldExtract`. We will rewrite the +/// `fieldExtract` into `CastStorageToLogicalDeref(fieldAddr(ptr, member))`, and the IR become: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// m = CastStorageToLogicalDeref(m_ptr); +/// call f, m +/// ``` +/// 4. Since there are no more uses of `m` that can be translated, stop. Note that it is possible +/// to continue specializing `f` and replace its first parameter's type to storage type. However +/// this implementation currently does not specialize functions whose parameter type is not a +/// pointer/reference type. When we target SPIRV, we will already be running the +/// `transformParamsToConstRef` pass that would have converted `f` to take in `ConstRef`. +/// In this case, the initial IR would be in the form of +/// ``` +/// ptr : StorageType* = ... +/// lptr : LogicalType* = CastStorageToLogical(ptr); +/// l = load(lptr) +/// m = fieldExtract(l, member) +/// var tmpVar : MemberLogiocalType [[ImmutableTempVar]] +/// store tmpVar, m +/// call f, tmpVar +/// ``` +/// To allow us to remove the `tmpVar` store introduced during `transformParamsToConstRef`, +/// this pass also handles the propagation through temp var stores. After pushing the cast +/// through `m`, we will get IR to this form: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// m = CastStorageToLogicalDeref(m_ptr); +/// var tmpVar : MemberLogiocalType [[ImmutableTempVar]] +/// store tmpVar, m +/// call f, tmpVar +/// ``` +/// This time, we will see that `m` is being used by a `store` into a `[[ImmutableTempVar]]` var, +/// and we can safely replace all uses of `tmpVar` to `m_ptr`, and therefore the IR will become: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// m = CastStorageToLogical(m_ptr); +/// call f, m_ptr +/// ``` +/// Now, we are in the case where a `CastStorageToLogical` is used as argument in a `call`. +/// This will trigger our function specialization rule to create `f_1` that accepets a +/// `StorageMember*`, and we will rewrite the IR again to: +/// ``` +/// ptr : StorageType* = ... +/// m_ptr = FieldAddr(ptr, member) +/// call f_1, m_ptr +/// ``` +/// +/// # Trailing Pointer Rewrite +/// +/// Another transformation done in this pass is it also rewrites struct with unsized trailing +/// arrays. Since an unsized type isn't a physical type and cannot be used as a pointee type, +/// we will have problem translating the following code to SPIRV: +/// ``` +/// struct Foo { int count; int[] values; } +/// uniform Foo* b; +/// ``` +/// +/// When we create a storage type for `Foo`, we will define it as: +/// ``` +/// struct Foo_std430 { int count; } +/// ``` +/// Where we removed the trailing array. +/// This makes `Foo_std430` an ordinary sized type that can be used freely as pointee type +/// in SPIRV. +/// +/// However this does mean that we also need to translate things like `ptr->values[2]` +/// into `((int*)(ptr+1))[2]`. Which we also handle during step 2 of the algorithm. +/// (`maybeTranslateTrailingPointerGetElementAddress`) +/// + namespace Slang { @@ -85,7 +290,7 @@ struct LoweredElementTypeContext else { auto val = builder.emitIntrinsicInst( - tryGetPointedToType(&builder, dest->getDataType()), + tryGetPointedToOrBufferElementType(&builder, dest->getDataType()), op, 1, &operand); @@ -145,6 +350,22 @@ struct LoweredElementTypeContext TargetProgram* target; BufferElementTypeLoweringOptions options; + struct SpecializationKey + { + IRFunc* callee; + IRFuncType* specializedFuncType; + bool operator==(const SpecializationKey& other) const + { + return (callee == other.callee && specializedFuncType == other.specializedFuncType); + } + HashCode64 getHashCode() const + { + return combineHash(Slang::getHashCode(callee), Slang::getHashCode(specializedFuncType)); + } + }; + // Specialized functions that takes storage-typed pointers instead of logical-typed pointers. + Dictionary specializedFuncs; + LoweredElementTypeContext( TargetProgram* target, BufferElementTypeLoweringOptions inOptions, @@ -881,17 +1102,25 @@ struct LoweredElementTypeContext IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) { - if (as(originalPtrLikeType) || as(originalPtrLikeType) || + IRBuilder builder(newElementType); + builder.setInsertAfter(newElementType); + if (auto ptrType = as(originalPtrLikeType)) + { + return builder.getPtrType(newElementType, ptrType); + } + + if (as(originalPtrLikeType) || as(originalPtrLikeType) || as(originalPtrLikeType)) { - IRBuilder builder(newElementType); - builder.setInsertAfter(newElementType); ShortList operands; - for (UInt i = 0; i < originalPtrLikeType->getOperandCount(); i++) + operands.add(newElementType); + for (UInt i = 1; i < originalPtrLikeType->getOperandCount(); i++) + { operands.add(originalPtrLikeType->getOperand(i)); - operands[0] = newElementType; - return builder.getType( + } + return (IRType*)builder.emitIntrinsicInst( + builder.getTypeKind(), originalPtrLikeType->getOp(), (UInt)operands.getCount(), operands.getArrayView().getBuffer()); @@ -914,26 +1143,544 @@ struct LoweredElementTypeContext TypeLoweringConfig config; }; - IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst) + IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst, IRInst* baseAddr) { switch (loadStoreInst->getOp()) { case kIROp_Load: case kIROp_Store: - return loadStoreInst->getOperand(0); + return baseAddr; case kIROp_StructuredBufferLoad: case kIROp_StructuredBufferLoadStatus: case kIROp_RWStructuredBufferLoad: case kIROp_RWStructuredBufferLoadStatus: case kIROp_RWStructuredBufferStore: return builder.emitRWStructuredBufferGetElementPtr( - loadStoreInst->getOperand(0), + baseAddr, loadStoreInst->getOperand(1)); default: return nullptr; } } + bool maybeTranslateTrailingPointerGetElementAddress( + IRBuilder& builder, + IRFieldAddress* fieldAddr, + IRCastStorageToLogicalBase* castInst, + TypeLoweringConfig& config, + List& castInstWorkList) + { + // If we are accessing an unsized array element from a pointer, we need to + // compute + // the trailing ptr that points to the first element of the array. + // And then replace all getElementPtr(arrayPtr, index) with + // getOffsetPtr(trailingPtr, index). + + auto ptrType = as(fieldAddr->getDataType()); + if (!ptrType) + return false; + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return false; + if (auto unsizedArrayType = as(ptrType->getValueType())) + { + builder.setInsertBefore(fieldAddr); + auto newArrayPtrVal = fieldAddr->getBase(); + auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), config); + + IRSizeAndAlignment arrayElementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + loweredInnerType.loweredType, + &arrayElementSizeAlignment); + IRSizeAndAlignment baseSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()), + &baseSizeAlignment); + + // Convert pointer to uint64 and adjust offset. + IRIntegerValue offset = baseSizeAlignment.size; + offset = align(offset, arrayElementSizeAlignment.alignment); + if (offset != 0) + { + auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal); + newArrayPtrVal = builder.emitAdd( + rawPtr->getFullType(), + rawPtr, + builder.getIntValue(builder.getUInt64Type(), offset)); + } + newArrayPtrVal = builder.emitBitCast( + builder.getPtrType(loweredInnerType.loweredType, ptrType), + newArrayPtrVal); + traverseUses( + fieldAddr, + [&](IRUse* fieldAddrUse) + { + auto fieldAddrUser = fieldAddrUse->getUser(); + if (fieldAddrUser->getOp() == kIROp_GetElementPtr) + { + builder.setInsertBefore(fieldAddrUser); + auto newElementPtr = + builder.emitGetOffsetPtr(newArrayPtrVal, fieldAddrUser->getOperand(1)); + auto castedGEP = builder.emitCastStorageToLogical( + fieldAddrUser->getFullType(), + newElementPtr, + castInst->getBufferType()); + fieldAddrUser->replaceUsesWith(castedGEP); + fieldAddrUser->removeAndDeallocate(); + if (auto castStorage = as(castedGEP)) + castInstWorkList.add(castStorage); + } + else if (fieldAddrUser->getOp() == kIROp_GetOffsetPtr) + { + } + else + { + SLANG_UNEXPECTED("unknown use of pointer to unsized array."); + } + }); + SLANG_ASSERT(!fieldAddr->hasUses()); + fieldAddr->removeAndDeallocate(); + return true; + } + return false; + } + + + // Helper function to discover all `call`s in `func` that has at least one argument + // that is `CastStorageToPhysical`. + void discoverCallsToProcess(List& callWorkList, IRFunc* func) + { + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + auto call = as(inst); + if (!call) + continue; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (arg->getOp() == kIROp_CastStorageToLogical) + { + callWorkList.add(call); + break; + } + } + } + } + } + + void deferStorageToLogicalCasts( + IRModule* module, + List castInstWorkList) + { + IRBuilder builder(module); + + while (castInstWorkList.getCount()) + { + // We process call instructions after other instructions, so we + // can be sure that all castStorageToLogical insts have already + // been pushed to the call argument lists before we process it. + HashSet callWorkListSet; + // Defer the storage-to-logical cast operation to latest possible time to avoid + // unnecessary packing/unpacking. + for (Index i = 0; i < castInstWorkList.getCount(); i++) + { + auto castInst = castInstWorkList[i]; + auto ptrVal = castInst->getOperand(0); + auto config = + getTypeLoweringConfigForBuffer(target, (IRType*)castInst->getBufferType()); + traverseUses( + castInst, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_FieldAddress: + if (!isUseBaseAddrOperand(use, user)) + break; + // If our logical struct type ends with an unsized array field, the + // storage struct type won't have this field defined. + // Therefore, all fieldAddress(obj, lastField) inst retrieving the last + // field of such struct should be translated into + // `(ArrayElementType*)((StorageStruct*)(obj)+1) + idx`. + // That is, we should first compute the tailing pointer of the + // struct, and replace all getElementPtr(fieldAddr, idx) with + // getOffsetPtr(tailingPtr, idx). + if (maybeTranslateTrailingPointerGetElementAddress( + builder, + (IRFieldAddress*)user, + castInst, + config, + castInstWorkList)) + return; + [[fallthrough]]; + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_RWStructuredBufferGetElementPtr: + { + // gep(castStorageToLogical(x)) ==> castStorageToLogical(gep(x)) + if (!isUseBaseAddrOperand(use, user)) + break; + auto logicalBaseType = castInst->getDataType(); + auto logicalType = user->getDataType(); + IRInst* storageBaseAddr = ptrVal; + auto originalBaseValueType = + tryGetPointedToOrBufferElementType(&builder, logicalBaseType); + if (user->getOp() == kIROp_GetElementPtr) + { + // If original type is an array, the lowered type will be a + // struct. In that case, all existing address insts should be + // appended with a field extract. + if (as(originalBaseValueType)) + { + auto arrayLowerInfo = + getLoweredTypeInfo(originalBaseValueType, config); + if (arrayLowerInfo.loweredInnerArrayType) + { + builder.setInsertBefore(user); + List args; + for (UInt i = 0; i < user->getOperandCount(); i++) + args.add(user->getOperand(i)); + storageBaseAddr = builder.emitFieldAddress( + builder.getPtrType( + arrayLowerInfo.loweredInnerArrayType), + ptrVal, + arrayLowerInfo.loweredInnerStructKey); + } + } + if (as(originalBaseValueType)) + { + // We are tring to get a pointer to a lowered matrix + // element. We process this insts at a later phase. + SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); + lowerMatrixAddresses( + module, + MatrixAddrWorkItem{user, config}); + break; + } + } + + + builder.setInsertBefore(user); + IRInst* storageGEP = nullptr; + switch (user->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + // For standard gep instructions, use the + // IR builder to auto-deduce result type + // of the new GEP inst. + ShortList newArgs; + for (UInt i = 1; i < user->getOperandCount(); i++) + newArgs.add(user->getOperand(i)); + storageGEP = builder.emitElementAddress( + storageBaseAddr, + newArgs.getArrayView().arrayView); + break; + } + default: + { + // For non-standard gep instructions, e.g. + // RWStructuredBufferGetElementPtr, + // manually create the inst here. + ShortList newArgs; + newArgs.add(storageBaseAddr); + for (UInt i = 1; i < user->getOperandCount(); i++) + newArgs.add(user->getOperand(i)); + auto logicalValueType = tryGetPointedToOrBufferElementType( + &builder, + logicalType); + auto storageTypeInfo = + getLoweredTypeInfo(logicalValueType, config); + storageGEP = builder.emitIntrinsicInst( + builder.getPtrType(storageTypeInfo.loweredType), + user->getOp(), + newArgs.getCount(), + newArgs.getArrayView().getBuffer()); + break; + } + } + auto castOfGEP = builder.emitCastStorageToLogical( + logicalType, + storageGEP, + castInst->getBufferType()); + user->replaceUsesWith(castOfGEP); + user->removeAndDeallocate(); + if (auto castStorage = as(castOfGEP)) + castInstWorkList.add(castStorage); + break; + } + case kIROp_Call: + { + // call(f, castStorageToLogical(x)) ==> call(f', x) + // + // If we see a call that takes a logical typed pointer, we will + // specialize the callee to take a storage typed pointer instead, + // and push the cast to inside the callee. + // We will process calls after other gep insts, so for now just add + // it into a separate worklist. + if (castInst->getOp() == kIROp_CastStorageToLogical) + { + callWorkListSet.add((IRCall*)user); + } + break; + } + case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_StructuredBufferConsume: + { + // If we see a load(CastStorageToLogical(storageAddr)), + // then based on what `storageAddr` is, we will push down + // the cast differently. + // - If `storageAddr` is already a tempVar that we introduced to + // hold the value of a buffer resource load, we can simply + // convert this into `CastStorageToLogicalDeref(storageAddr)`. + // - Otherwise, if `storageAddr` is a buffer location, we will + // create a temp var to hold the result of the memory load, + // Then we create a `CastStorageToLogicalDeref(tempVar)` + // structure and use it to replace `user`. + // Note that it is important to introduce a temp var and preserve + // the buffer load operation, so we are not changing the memory + // semantics of the original program. + if (!isUseBaseAddrOperand(use, user)) + break; + // If loaded value is itself a pointer or buffer, + // stop pushing the cast along the resulting address. + // we will handle loads from the pointer separately. + if (as(user->getDataType()) || + as(user->getDataType()) || + as(user->getDataType())) + break; + // Don't push the cast beyond the load if we are already + // a simple type. + if (!isCompositeType(user->getDataType())) + break; + builder.setInsertBefore(user); + IRCloneEnv cloneEnv; + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setOperand(0, ptrVal); + auto elementStorageType = tryGetPointedToOrBufferElementType( + &builder, + ptrVal->getDataType()); + newLoad->setFullType(elementStorageType); + IRInst* tempVar = nullptr; + if (as(user)) + { + auto rootAddr = getRootAddr(ptrVal); + if (rootAddr->findDecorationImpl( + kIROp_TempCallArgImmutableVarDecoration)) + tempVar = ptrVal; + } + if (!tempVar) + { + tempVar = builder.emitVar(elementStorageType); + builder.addDecoration( + tempVar, + kIROp_TempCallArgImmutableVarDecoration); + builder.emitStore(tempVar, newLoad); + } + auto newCast = builder.emitCastStorageToLogicalDeref( + user->getFullType(), + tempVar, + castInst->getBufferType()); + user->replaceUsesWith(newCast); + user->removeAndDeallocate(); + castInstWorkList.add(newCast); + break; + } + case kIROp_FieldExtract: + case kIROp_GetElement: + { + if (!isUseBaseAddrOperand(use, user)) + break; + // elementExtract(castStorageToLogicalDeref(addr), key) + // ==> load(gep(castStorageToLogical(addr), key) + builder.setInsertBefore(user); + auto castAddr = builder.emitCastStorageToLogical( + builder.getPtrType(castInst->getDataType()), + ptrVal, + castInst->getBufferType()); + IRInst* gep = nullptr; + if (user->getOp() == kIROp_GetElement) + gep = builder.emitElementAddress(castAddr, user->getOperand(1)); + else + gep = builder.emitFieldAddress(castAddr, user->getOperand(1)); + auto load = builder.emitLoad(gep); + user->replaceUsesWith(load); + user->removeAndDeallocate(); + if (auto castStorage = as(castAddr)) + castInstWorkList.add(castStorage); + break; + } + case kIROp_Store: + { + // If we see `store(tempVar, castStorageToLogicalDeref(addr))`, + // replace `tempVar` with `castStorageToLogical(addr)`. + if (castInst->getOp() != kIROp_CastStorageToLogicalDeref) + break; + auto store = as(user); + if (store->getVal() != castInst) + break; + auto dest = store->getPtr(); + if (!dest->findDecorationImpl( + kIROp_TempCallArgImmutableVarDecoration)) + break; + builder.setInsertBefore(user); + auto castAddr = builder.emitCastStorageToLogical( + builder.getPtrType(castInst->getDataType()), + ptrVal, + castInst->getBufferType()); + dest->replaceUsesWith(castAddr); + dest->removeAndDeallocate(); + if (auto castStorage = as(castAddr)) + castInstWorkList.add(castStorage); + break; + } + } + }); + } + + // Now that we have processed all GEP instructions, we can now proceed to + // process all calls. This is done by making a clone of the callee, and change + // the parameter type from logical type to storage type, and insert a + // castStorageToLogical on the parameter. Then we go back to the beginning and make sure + // we process those newly created castStorageToLogical insts. + List newCasts; + List callWorkList; + for (auto call : callWorkListSet) + callWorkList.add(call); + for (Index c = 0; c < callWorkList.getCount(); c++) + { + auto call = callWorkList[c]; + auto calleeFunc = as(call->getCallee()); + // We compute the func type for the specialized func based on the arguments + // provided, and check the specialization cache to reuse existing specialization + // when possible. + List oldParams; + for (auto param : calleeFunc->getParams()) + oldParams.add(param); + SLANG_ASSERT(oldParams.getCount() == (Index)call->getArgCount()); + + ShortList paramTypes; + ShortList newArgs; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (auto castArg = as(arg)) + { + auto oldParamPtrType = oldParams[i]->getDataType(); + auto storageValueType = tryGetPointedToOrBufferElementType( + &builder, + castArg->getOperand(0)->getDataType()); + auto storagePtrType = + getLoweredPtrLikeType(oldParamPtrType, storageValueType); + paramTypes.add(storagePtrType); + newArgs.add(castArg->getOperand(0)); + } + else + { + paramTypes.add(arg->getDataType()); + newArgs.add(arg); + } + } + auto specializedFuncType = builder.getFuncType( + (UInt)paramTypes.getCount(), + paramTypes.getArrayView().getBuffer(), + call->getDataType()); + auto key = SpecializationKey{(IRFunc*)calleeFunc, specializedFuncType}; + IRFunc* specializedFunc = nullptr; + if (!specializedFuncs.tryGetValue(key, specializedFunc)) + { + specializedFunc = createSpecializedFuncThatUseStorageType( + call, + specializedFuncType, + newCasts); + specializedFuncs[key] = specializedFunc; + + // The cloned function may also contain `call`s with + // `CastStorageToLogical` arguments, and we want to add + // thoses calls to the callWorkList for further processing. + discoverCallsToProcess(callWorkList, specializedFunc); + } + builder.setInsertBefore(call); + auto newCall = builder.emitCallInst( + call->getFullType(), + specializedFunc, + newArgs.getArrayView().arrayView); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + + // Remove any casts that have no more uses. + for (auto cast : castInstWorkList) + { + if (!cast->hasUses()) + cast->removeAndDeallocate(); + } + + // Continue to process new casts added during function specialization. + castInstWorkList.swapWith(newCasts); + } + } + + IRFunc* createSpecializedFuncThatUseStorageType( + IRCall* call, + IRFuncType* specializedFuncType, + List& outNewCasts) + { + IRBuilder builder(call); + builder.setInsertBefore(call->getCallee()); + + // Create a clone of the callee. + IRCloneEnv cloneEnv; + auto clonedFunc = as(cloneInst(&cloneEnv, &builder, call->getCallee())); + List uses; + + // If a parameter is being translated to storage type, + // insert a cast to convert it to logical type. + List params; + for (auto param : clonedFunc->getParams()) + params.add(param); + for (UInt i = 0; i < (UInt)params.getCount(); i++) + { + auto param = params[i]; + SLANG_RELEASE_ASSERT(i < call->getArgCount()); + auto arg = call->getArg(i); + auto cast = as(arg); + if (!cast) + continue; + auto logicalParamType = param->getFullType(); + auto storageType = specializedFuncType->getParamType(i); + param->setFullType((IRType*)storageType); + setInsertAfterOrdinaryInst(&builder, param); + + // Store uses of param before creating a cast inst that uses it. + uses.clear(); + for (auto use = param->firstUse; use; use = use->nextUse) + uses.add(use); + auto castedParam = + builder.emitCastStorageToLogical(logicalParamType, param, cast->getBufferType()); + if (auto castStorage = as(castedParam)) + outNewCasts.add(castStorage); + + // Replace all previous uses of param to use castedParam instead. + for (auto use : uses) + builder.replaceOperand(use, castedParam); + } + clonedFunc->setFullType(specializedFuncType); + removeLinkageDecorations(clonedFunc); + return clonedFunc; + } + void processModule(IRModule* module) { IRBuilder builder(module); @@ -941,6 +1688,7 @@ struct LoweredElementTypeContext { IRType* bufferType; IRType* elementType; + IRType* loweredBufferType = nullptr; bool shouldWrapArrayInStruct = false; }; List bufferTypeInsts; @@ -990,12 +1738,10 @@ struct LoweredElementTypeContext bufferTypeInsts.add(BufferTypeInfo{(IRType*)globalInst, elementType}); } - // Maintain a pending work list of all matrix addresses, and try to lower them out of - // existance after everything else has been lowered. - List matrixAddrInsts; + List castInstWorkList; - for (auto bufferTypeInfo : bufferTypeInsts) + for (auto& bufferTypeInfo : bufferTypeInsts) { auto bufferType = bufferTypeInfo.bufferType; auto elementType = bufferTypeInfo.elementType; @@ -1022,10 +1768,10 @@ struct LoweredElementTypeContext (UInt)typeOperands.getCount(), typeOperands.getArrayView().getBuffer()); - // We treat a value of a buffer type as a pointer, and use a work list to translate - // all loads and stores through the pointer values that needs lowering. + // Replace all global buffer declarations to use the storage type instead, + // and insert initial `castStorageToLogical` instructions to convert the + // storage-typed pointer to logical-typed pointer. - List ptrValsWorkList; traverseUses( bufferType, [&](IRUse* use) @@ -1033,433 +1779,400 @@ struct LoweredElementTypeContext auto user = use->getUser(); if (use != &user->typeUse) return; - ptrValsWorkList.add(use->getUser()); + // We don't want to insert cast instructions for uses of + // intermediate address instruction that are themselves + // derived from some other base address. We will let + // the later part of the pass to systematically propagate + // the cast through them. + switch (user->getOp()) + { + case kIROp_FieldAddress: + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_RWStructuredBufferGetElementPtr: + return; + } + auto ptrVal = use->getUser(); + setInsertAfterOrdinaryInst(&builder, ptrVal); + builder.replaceOperand(use, loweredBufferType); + auto logicalBufferType = getLoweredPtrLikeType(bufferType, elementType); + auto castStorageToLogical = + builder.emitCastStorageToLogical(logicalBufferType, ptrVal, bufferType); + traverseUses( + ptrVal, + [&](IRUse* ptrUse) + { + if (ptrUse->getUser() != castStorageToLogical) + builder.replaceOperand(ptrUse, castStorageToLogical); + }); + if (auto castStorage = as(castStorageToLogical)) + castInstWorkList.add(castStorage); }); + bufferTypeInfo.loweredBufferType = loweredBufferType; + } + + // Push down `CastStorageToLogical` insts we inserted above to latest possible locations, + // specializing all function calls along the way, until we truly need the the logical value. + // This means that `FieldAddr(CastStorageToLogical(buffer), field0))` is translated to + // `CastStorageToLogical(FieldAddr(buffer, field0))`. This way we can be sure that we are + // doing minimal packing/unpacking. + deferStorageToLogicalCasts(module, _Move(castInstWorkList)); + + // Now translate the `CastStorageToLogical` into actual packing/unpacking code. + materializeStorageToLogicalCasts(module->getModuleInst()); + + // Replace all remaining uses of bufferType to loweredBufferType, these uses are + // non-operational and should be directly replaceable, such as uses in `IRFuncType`. + for (auto bufferTypeInst : bufferTypeInsts) + { + if (!bufferTypeInst.loweredBufferType) + continue; + bufferTypeInst.bufferType->replaceUsesWith(bufferTypeInst.loweredBufferType); + bufferTypeInst.bufferType->removeAndDeallocate(); + } + } - // Translate the values to use new lowered buffer type instead. - for (Index i = 0; i < ptrValsWorkList.getCount(); i++) + void materializeStorageToLogicalCastsImpl(IRCastStorageToLogicalBase* castInst) + { + IRBuilder builder(castInst); + if (!castInst->hasUses()) + { + castInst->removeAndDeallocate(); + return; + } + if (castInst->getOp() == kIROp_CastStorageToLogicalDeref) + { + // Convert CastStorageToLogicalDeref to load(CastStorageToLogical) to reuse + // the same materialization logic for CastStorageToLogical. + // + builder.setInsertBefore(castInst); + auto ptrType = builder.getPtrType(castInst->getDataType()); + auto castPtr = builder.emitCastStorageToLogical( + (IRType*)ptrType, + castInst->getVal(), + castInst->getBufferType()); + auto load = builder.emitLoad(castPtr); + castInst->replaceUsesWith(load); + castInst->removeAndDeallocate(); + if (auto castStorage = as(castPtr)) + materializeStorageToLogicalCastsImpl(castStorage); + return; + } + + // Translate the values to use new lowered buffer type instead. + + auto ptrVal = castInst->getOperand(0); + auto oldPtrType = castInst->getFullType(); + auto originalElementType = oldPtrType->getOperand(0); + auto config = getTypeLoweringConfigForBuffer(target, (IRType*)castInst->getBufferType()); + + + LoweredElementTypeInfo loweredElementTypeInfo = {}; + if (auto getElementPtr = as(ptrVal)) + { + if (auto arrayType = as(tryGetPointedToOrBufferElementType( + &builder, + getElementPtr->getBase()->getDataType()))) { - auto ptrVal = ptrValsWorkList[i]; - auto oldPtrType = ptrVal->getFullType(); - auto originalElementType = oldPtrType->getOperand(0); - - // If we are accessing an unsized array element from a pointer, we need to compute - // the trailing ptr that points to the first element of the array. - // And then replace all getElementPtr(arrayPtr, index) with - // getOffsetPtr(trailingPtr, index). - if (auto fieldAddr = as(ptrVal)) + // For WGSL, an array of scalar or vector type will always be converted to + // an array of 16-byte aligned vector type. In this case, we will run into a + // GetElementPtr where the result type is different from the element type of + // the base array. + // We should setup loweredElementTypeInfo so the remaining logic can handle + // this case and insert proper packing/unpacking logic around it. + if (arrayType->getElementType() != originalElementType && + isScalarOrVectorType(originalElementType)) { - auto handleUnsizedArrayAccess = [&]() -> bool - { - auto ptrType = as(ptrVal->getDataType()); - if (!ptrType) - return false; - if (ptrType->getAddressSpace() != AddressSpace::UserPointer) - return false; - if (auto unsizedArrayType = as(ptrType->getValueType())) - { - builder.setInsertBefore(ptrVal); - auto newArrayPtrVal = fieldAddr->getBase(); - auto loweredInnerType = - getLoweredTypeInfo(unsizedArrayType->getElementType(), config); + loweredElementTypeInfo.loweredType = arrayType->getElementType(); + loweredElementTypeInfo.originalType = (IRType*)originalElementType; + loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod( + loweredElementTypeInfo.originalType, + loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod( + loweredElementTypeInfo.loweredType, + loweredElementTypeInfo.originalType); + } + } + } - IRSizeAndAlignment arrayElementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - loweredInnerType.loweredType, - &arrayElementSizeAlignment); - IRSizeAndAlignment baseSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), - &baseSizeAlignment); + // For general cases we simply check if the element type needs lowering. + // If so we will insert packing/unpacking logic if necessary. + // + if (!loweredElementTypeInfo.loweredType) + { + loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType, config); + } - // Convert pointer to uint64 and adjust offset. - IRIntegerValue offset = baseSizeAlignment.size; - offset = align(offset, arrayElementSizeAlignment.alignment); - if (offset != 0) - { - auto rawPtr = - builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal); - newArrayPtrVal = builder.emitAdd( - rawPtr->getFullType(), - rawPtr, - builder.getIntValue(builder.getUInt64Type(), offset)); - } - newArrayPtrVal = builder.emitBitCast( - builder.getPtrType( - loweredInnerType.loweredType, - ptrType->getAddressSpace()), - newArrayPtrVal); - traverseUses( - ptrVal, - [&](IRUse* use) - { - auto user = use->getUser(); - if (user->getOp() == kIROp_GetElementPtr) - { - builder.setInsertBefore(user); - auto newElementPtr = builder.emitGetOffsetPtr( - newArrayPtrVal, - user->getOperand(1)); - user->replaceUsesWith(newElementPtr); - user->removeAndDeallocate(); - ptrValsWorkList.add(newElementPtr); - } - else if (user->getOp() == kIROp_GetOffsetPtr) - { - } - else - { - SLANG_UNEXPECTED( - "unknown use of pointer to unsized array."); - } - }); - SLANG_ASSERT(!ptrVal->hasUses()); - ptrVal->removeAndDeallocate(); - return true; - } - return false; - }; - if (handleUnsizedArrayAccess()) - continue; - } + if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType) + { + castInst->replaceUsesWith(ptrVal); + castInst->removeAndDeallocate(); + return; + } - LoweredElementTypeInfo loweredElementTypeInfo = {}; - if (auto getElementPtr = as(ptrVal)) + traverseUses( + castInst, + [&](IRUse* use) + { + auto user = use->getUser(); + if (as(user)) + return; + switch (user->getOp()) { - if (auto arrayType = as( - tryGetPointedToType(&builder, getElementPtr->getBase()->getDataType()))) + case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_StructuredBufferConsume: { - // For WGSL, an array of scalar or vector type will always be converted to - // an array of 16-byte aligned vector type. In this case, we will run into a - // GetElementPtr where the result type is different from the element type of - // the base array. - // We should setup loweredElementTypeInfo so the remaining logic can handle - // this case and insert proper packing/unpacking logic around it. - if (arrayType->getElementType() != originalElementType && - isScalarOrVectorType(originalElementType)) + if (castInst != user->getOperand(0)) + break; + builder.setInsertBefore(user); + auto addr = getBufferAddr(builder, user, ptrVal); + if (!addr) + { + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + builder.emitStore(addr, newLoad); + } + if (auto alignedAttr = user->findAttr()) { - loweredElementTypeInfo.loweredType = arrayType->getElementType(); - loweredElementTypeInfo.originalType = (IRType*)originalElementType; - loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod( - loweredElementTypeInfo.originalType, - loweredElementTypeInfo.loweredType); - loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod( - loweredElementTypeInfo.loweredType, - loweredElementTypeInfo.originalType); + builder.addAlignedAddressDecoration(addr, alignedAttr->getAlignment()); } + auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply( + builder, + loweredElementTypeInfo.originalType, + addr); + user->replaceUsesWith(unpackedVal); + user->removeAndDeallocate(); + return; } - } - - // For general cases we simply check if the element type needs lowering. - // If so we will insert packing/unpacking logic if necessary. - // - if (!loweredElementTypeInfo.loweredType) - { - loweredElementTypeInfo = - getLoweredTypeInfo((IRType*)originalElementType, config); - } - - if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType) - continue; - - ptrVal->setFullType(getLoweredPtrLikeType( - ptrVal->getFullType(), - loweredElementTypeInfo.loweredType)); - - traverseUses( - ptrVal, - [&](IRUse* use) + case kIROp_Store: + case kIROp_RWStructuredBufferStore: + case kIROp_StructuredBufferAppend: { - auto user = use->getUser(); - if (as(user)) - return; - switch (user->getOp()) + // Use must be the dest operand of the store inst. + if (use != user->getOperands() + 0) + break; + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto originalVal = getStoreVal(user); + if (auto sbAppend = as(user)) { - case kIROp_Load: - case kIROp_StructuredBufferLoad: - case kIROp_StructuredBufferLoadStatus: - case kIROp_RWStructuredBufferLoad: - case kIROp_RWStructuredBufferLoadStatus: - case kIROp_StructuredBufferConsume: + builder.setInsertBefore(sbAppend); + IRInst* addr = nullptr; + if (originalVal->getOp() == kIROp_CastStorageToLogicalDeref) { - builder.setInsertBefore(user); - auto addr = getBufferAddr(builder, user); - if (!addr) - { - IRCloneEnv cloneEnv = {}; - builder.setInsertBefore(user); - auto newLoad = cloneInst(&cloneEnv, &builder, user); - newLoad->setFullType(loweredElementTypeInfo.loweredType); - addr = builder.emitVar(loweredElementTypeInfo.loweredType); - builder.emitStore(addr, newLoad); - } - if (auto alignedAttr = user->findAttr()) - { - builder.addAlignedAddressDecoration( - addr, - alignedAttr->getAlignment()); - } - auto unpackedVal = - loweredElementTypeInfo.convertLoweredToOriginal.apply( - builder, - loweredElementTypeInfo.originalType, - addr); - user->replaceUsesWith(unpackedVal); - user->removeAndDeallocate(); - break; + addr = originalVal->getOperand(0); } - case kIROp_Store: - case kIROp_RWStructuredBufferStore: - case kIROp_StructuredBufferAppend: + else { - // Use must be the dest operand of the store inst. - if (use != user->getOperands() + 0) - break; - IRCloneEnv cloneEnv = {}; - builder.setInsertBefore(user); - auto originalVal = getStoreVal(user); - IRInst* addr = getBufferAddr(builder, user); - if (addr) - { - if (auto alignedAttr = user->findAttr()) - { - builder.addAlignedAddressDecoration( - addr, - alignedAttr->getAlignment()); - } - - loweredElementTypeInfo.convertOriginalToLowered - .applyDestinationDriven(builder, addr, originalVal); - user->removeAndDeallocate(); - } - else if (auto sbAppend = as(user)) - { - builder.setInsertBefore(sbAppend); - addr = builder.emitVar(loweredElementTypeInfo.loweredType); - loweredElementTypeInfo.convertOriginalToLowered - .applyDestinationDriven(builder, addr, originalVal); - auto packedVal = builder.emitLoad(addr); - sbAppend->setOperand(1, packedVal); - } - else - { - SLANG_UNREACHABLE("unhandled store type"); - } - break; + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, addr, originalVal); } - case kIROp_GetElementPtr: - case kIROp_FieldAddress: + auto packedVal = builder.emitLoad(addr); + sbAppend->setOperand(1, packedVal); + } + else + { + IRInst* addr = getBufferAddr(builder, user, ptrVal); + if (auto alignedAttr = user->findAttr()) { - // If original type is an array, the lowered type will be a struct. - // In that case, all existing address insts should be appended with - // a field extract. - if (as(originalElementType)) - { - builder.setInsertBefore(user); - List args; - for (UInt i = 0; i < user->getOperandCount(); i++) - args.add(user->getOperand(i)); - auto newArrayPtrVal = builder.emitFieldAddress( - builder.getPtrType( - loweredElementTypeInfo.loweredInnerArrayType), - ptrVal, - loweredElementTypeInfo.loweredInnerStructKey); - builder.replaceOperand(use, newArrayPtrVal); - ptrValsWorkList.add(user); - } - else if (as(originalElementType)) - { - // We are tring to get a pointer to a lowered matrix element. - // We process this insts at a later phase. - SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); - matrixAddrInsts.add(MatrixAddrWorkItem{user, config}); - } - else - { - // If we getting a derived address from the pointer, we need - // to recursively lower the new address. We do so by pushing - // the address inst into the work list. - ptrValsWorkList.add(user); - } + builder.addAlignedAddressDecoration( + addr, + alignedAttr->getAlignment()); } - break; - case kIROp_RWStructuredBufferGetElementPtr: - case kIROp_GetOffsetPtr: - ptrValsWorkList.add(user); - break; - case kIROp_StructuredBufferGetDimensions: - break; - case kIROp_Call: + if (originalVal->getOp() == kIROp_CastStorageToLogicalDeref) + { + auto valAddr = originalVal->getOperand(0); + auto storageVal = builder.emitLoad(valAddr); + builder.emitStore(addr, storageVal); + } + else { - // If a structured buffer or pointer typed value is used directly as - // an argument, we don't need to do any marshalling here. - if (as(ptrVal->getDataType())) - break; - if (options.lowerBufferPointer && - as(ptrVal->getDataType())) - break; - // If we are calling a function with an l-value pointer from buffer - // access, we need to materialize the object as a local variable, - // and pass the address of the local variable to the function. - builder.setInsertBefore(user); - auto unpackedVal = - loweredElementTypeInfo.convertLoweredToOriginal.apply( - builder, - (IRType*)originalElementType, - ptrVal); - auto var = builder.emitVar((IRType*)originalElementType); - builder.emitStore(var, unpackedVal); - use->set(var); - builder.setInsertAfter(user); - auto newVal = builder.emitLoad(var); loweredElementTypeInfo.convertOriginalToLowered - .applyDestinationDriven(builder, ptrVal, newVal); + .applyDestinationDriven(builder, addr, originalVal); } - break; - default: - break; + user->removeAndDeallocate(); } - }); - } + return; + } + default: + break; + } + // If the pointer is used in any other way that we don't recognize, + // preserve it as is without translation. + builder.setInsertBefore(user); + builder.replaceOperand(use, ptrVal); + }); + + if (!castInst->hasUses()) + castInst->removeAndDeallocate(); + } - // Replace all remaining uses of bufferType to loweredBufferType, these uses are - // non-operational and should be directly replaceable, such as uses in `IRFuncType`. - bufferType->replaceUsesWith(loweredBufferType); - bufferType->removeAndDeallocate(); + void collectInstsOfType(List& insts, IRInst* root, IROp op) + { + if (root->getOp() == op) + { + insts.add((IRCastStorageToLogicalBase*)root); + return; + } + for (auto child : root->getChildren()) + { + collectInstsOfType(insts, child, op); } + } - // Process all matrix address uses. - lowerMatrixAddresses(module, matrixAddrInsts); + void materializeStorageToLogicalCasts(IRInst* root) + { + // We will process all CastStorageToLogical insts first, before + // processing all CastStorageToLogicalDeref. + // This is because when we materialize a + // `store(CastStorageToLogical(addr), CastStorageToLogicalDeref(src))`, + // we can just fold out CastStorageToLogicalDeref and emit + // `store(addr, load(src))` instead. + // If we materialized `CastStorageToLogicalDeref` first we will + // miss this opportunity and generate more bloated code. + // + List castInsts; + collectInstsOfType(castInsts, root, kIROp_CastStorageToLogical); + for (auto inst : castInsts) + materializeStorageToLogicalCastsImpl(inst); + + castInsts.clear(); + collectInstsOfType(castInsts, root, kIROp_CastStorageToLogicalDeref); + for (auto inst : castInsts) + materializeStorageToLogicalCastsImpl(inst); } // Lower all getElementPtr insts of a lowered matrix out of existance. - void lowerMatrixAddresses(IRModule* module, List& matrixAddrInsts) + void lowerMatrixAddresses(IRModule* module, MatrixAddrWorkItem workItem) { IRBuilder builder(module); - for (auto workItem : matrixAddrInsts) - { - auto majorAddr = workItem.matrixAddrInst; - auto majorGEP = as(majorAddr); - SLANG_ASSERT(majorGEP); - auto loweredMatrixType = - cast(majorGEP->getBase()->getFullType())->getValueType(); - auto matrixTypeInfo = getTypeLoweringMap(workItem.config) - .mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); - SLANG_ASSERT(matrixTypeInfo); - auto matrixType = as(matrixTypeInfo->originalType); - auto rowCount = getIntVal(matrixType->getRowCount()); - traverseUses( - majorAddr, - [&](IRUse* use) + auto majorAddr = workItem.matrixAddrInst; + auto majorGEP = as(majorAddr); + SLANG_ASSERT(majorGEP); + auto baseCast = as(majorGEP->getBase()); + SLANG_ASSERT(baseCast); + auto storageBase = baseCast->getOperand(0); + auto loweredMatrixType = cast(storageBase->getFullType())->getValueType(); + auto matrixTypeInfo = + getTypeLoweringMap(workItem.config).mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); + SLANG_ASSERT(matrixTypeInfo); + if (matrixTypeInfo->loweredType == matrixTypeInfo->originalType) + return; + auto matrixType = as(matrixTypeInfo->originalType); + auto colCount = getIntVal(matrixType->getColumnCount()); + traverseUses( + majorAddr, + [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + switch (user->getOp()) { - auto user = use->getUser(); - builder.setInsertBefore(user); - switch (user->getOp()) + case kIROp_Load: { - case kIROp_Load: + IRInst* resultInst = nullptr; + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + storageBase, + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - IRInst* resultInst = nullptr; - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType( - majorAddr->getDataType(), - matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - if (getIntVal(matrixType->getLayout()) == - SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - List args; - for (IRIntegerValue i = 0; i < rowCount; i++) - { - auto vector = - builder.emitLoad(builder.emitElementAddress(dataPtr, i)); - auto element = - builder.emitElementExtract(vector, majorGEP->getIndex()); - args.add(element); - } - resultInst = builder.emitMakeVector( - builder.getVectorType( - matrixType->getElementType(), - (IRIntegerValue)args.getCount()), - args); - } - else + List args; + for (IRIntegerValue i = 0; i < colCount; i++) { + auto vector = + builder.emitLoad(builder.emitElementAddress(dataPtr, i)); auto element = - builder.emitElementAddress(dataPtr, majorGEP->getIndex()); - resultInst = builder.emitLoad(element); + builder.emitElementExtract(vector, majorGEP->getIndex()); + args.add(element); } - user->replaceUsesWith(resultInst); - user->removeAndDeallocate(); + resultInst = builder.emitMakeVector( + builder.getVectorType( + matrixType->getElementType(), + (IRIntegerValue)args.getCount()), + args); } - break; - case kIROp_Store: + else { - auto storeInst = cast(user); - if (storeInst->getOperand(0) != majorAddr) - break; - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType( - majorAddr->getDataType(), - matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - if (getIntVal(matrixType->getLayout()) == - SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue i = 0; i < rowCount; i++) - { - auto vectorAddr = builder.emitElementAddress(dataPtr, i); - auto elementAddr = builder.emitElementAddress( - vectorAddr, - majorGEP->getIndex()); - builder.emitStore( - elementAddr, - builder.emitElementExtract(storeInst->getVal(), i)); - } - } - else - { - auto rowAddr = - builder.emitElementAddress(dataPtr, majorGEP->getIndex()); - builder.emitStore(rowAddr, storeInst->getVal()); - user->removeAndDeallocate(); - } - break; + auto element = + builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + resultInst = builder.emitLoad(element); } - case kIROp_GetElementPtr: + user->replaceUsesWith(resultInst); + user->removeAndDeallocate(); + } + break; + case kIROp_Store: + { + auto storeInst = cast(user); + if (storeInst->getOperand(0) != majorAddr) + break; + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + storageBase, + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - auto gep2 = cast(user); - auto rowIndex = majorGEP->getIndex(); - auto colIndex = gep2->getIndex(); - if (getIntVal(matrixType->getLayout()) == - SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + for (IRIntegerValue i = 0; i < colCount; i++) { - Swap(rowIndex, colIndex); + auto vectorAddr = builder.emitElementAddress(dataPtr, i); + auto elementAddr = + builder.emitElementAddress(vectorAddr, majorGEP->getIndex()); + builder.emitStore( + elementAddr, + builder.emitElementExtract(storeInst->getVal(), i)); } - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType( - majorAddr->getDataType(), - matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); - auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); - gep2->replaceUsesWith(elementAddr); - gep2->removeAndDeallocate(); - break; } - default: - SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs " - "storage lowering."); + else + { + auto rowAddr = + builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + builder.emitStore(rowAddr, storeInst->getVal()); + user->removeAndDeallocate(); + } break; } - }); - } + case kIROp_GetElementPtr: + { + auto gep2 = cast(user); + auto rowIndex = majorGEP->getIndex(); + auto colIndex = gep2->getIndex(); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + Swap(rowIndex, colIndex); + } + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + storageBase, + matrixTypeInfo->loweredInnerStructKey); + auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); + auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); + gep2->replaceUsesWith(elementAddr); + gep2->removeAndDeallocate(); + break; + } + default: + SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs " + "storage lowering."); + break; + } + }); + if (!majorAddr->hasUses()) + majorAddr->removeAndDeallocate(); } }; diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h index 2c69c5476..9d6e53609 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.h +++ b/source/slang/slang-ir-lower-buffer-element-type.h @@ -10,7 +10,7 @@ struct IRType; struct BufferElementTypeLoweringOptions { - bool lowerBufferPointer = false; + bool lowerBufferPointer = true; // For WGSL, we can only create arrays that has a stride of 16 bytes for constant buffers. bool use16ByteArrayElementForConstantBuffer = false; diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index fd950b91a..e66617e72 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -135,6 +135,16 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner case kIROp_RWStructuredBufferGetElementPtr: outAddressSpace = AddressSpace::Global; return true; + case kIROp_Load: + { + auto addrSpace = getAddressSpaceFromVarType(inst->getDataType()); + if (addrSpace != AddressSpace::Generic) + { + outAddressSpace = addrSpace; + return true; + } + } + return false; default: return false; } @@ -256,10 +266,13 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) legalizeEntryPointVaryingParamsForMetal(module, sink, entryPoints); + processInst(module->getModuleInst(), sink); +} + +void specializeAddressSpaceForMetal(IRModule* module) +{ MetalAddressSpaceAssigner metalAddressSpaceAssigner; specializeAddressSpace(module, &metalAddressSpaceAssigner); - - processInst(module->getModuleInst(), sink); } } // namespace Slang diff --git a/source/slang/slang-ir-metal-legalize.h b/source/slang/slang-ir-metal-legalize.h index e49c1b9e0..98c19de60 100644 --- a/source/slang/slang-ir-metal-legalize.h +++ b/source/slang/slang-ir-metal-legalize.h @@ -7,4 +7,6 @@ namespace Slang class DiagnosticSink; void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink); +void specializeAddressSpaceForMetal(IRModule* module); + } // namespace Slang diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 94bb6b67c..4c10cf246 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -196,6 +196,9 @@ static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func) continue; } + if (destPtr->findDecorationImpl(kIROp_DisableCopyEliminationDecoration)) + continue; + // Check if we're storing a load result auto loadInst = as(storedValue); if (!loadInst) diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index 29f1ec516..c4a155eec 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -168,6 +168,7 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext { case kIROp_Var: case kIROp_RWStructuredBufferGetElementPtr: + case kIROp_Load: { // The address space of these insts should be assigned by the initial // address space assigner. @@ -204,16 +205,6 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext } } break; - case kIROp_Load: - { - if (auto addrSpace = - mapVarValueToAddrSpace.tryGetValue(inst->getOperand(0))) - { - mapInstToAddrSpace[inst] = *addrSpace; - changed = true; - } - } - break; case kIROp_Param: if (!isFirstBlock) { @@ -248,22 +239,24 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext if (callee) { List argAddrSpaces; - bool fullySpecialized = true; + bool hasSpecializableArg = false; for (UInt i = 0; i < callInst->getArgCount(); i++) { auto arg = callInst->getArg(i); - auto argAddrSpace = getAddrSpace(arg); argAddrSpaces.add(getAddrSpace(arg)); - if (argAddrSpace == AddressSpace::Generic && - as(arg->getDataType())) + if (as(arg->getDataType())) { - fullySpecialized = false; - break; + hasSpecializableArg = true; } } - if (!fullySpecialized) + if (!hasSpecializableArg) + { + workList.add(callee); + break; + } + // If callee doesn't have a body, don't specialize. + if (!callee->getFirstBlock()) break; - FuncSpecializationKey key(callee, argAddrSpaces); IRFunc* specializedCallee = nullptr; if (IRFunc** specializedFunc = @@ -484,4 +477,14 @@ void propagateAddressSpaceFromInsts(List&& workList) } } +AddressSpace NoOpInitialAddressSpaceAssigner::getAddressSpaceFromVarType(IRInst*) +{ + return AddressSpace::Generic; +} + +AddressSpace NoOpInitialAddressSpaceAssigner::getLeafInstAddressSpace(IRInst*) +{ + return AddressSpace::Generic; +} + } // namespace Slang diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h index 7e5f0fd9b..89145cf87 100644 --- a/source/slang/slang-ir-specialize-address-space.h +++ b/source/slang/slang-ir-specialize-address-space.h @@ -24,6 +24,13 @@ struct InitialAddressSpaceAssigner virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) = 0; }; +struct NoOpInitialAddressSpaceAssigner : public InitialAddressSpaceAssigner +{ + virtual bool tryAssignAddressSpace(IRInst*, AddressSpace&) { return false; } + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type); + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst); +}; + /// Propagate address space information through the IR module. /// Specialize functions with reference/pointer parameters to use the correct address space /// based on the address space of the arguments. diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 2c4bd11cc..f795a6559 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -893,18 +893,22 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(inst); builder.setInsertBefore(inst); auto funcType = as(funcValue->getFullType()); + bool argsChanged = false; for (UInt i = 0; i < inst->getArgCount(); i++) { auto arg = inst->getArg(i); auto paramType = funcType->getParamType(i); - if (as(paramType)) + if (auto ptrType = as(paramType)) { - // If the parameter has an explicit pointer type, - // then we know the user is using the variable pointer - // capability to pass a true pointer. - // In this case we should not rewrite the call. - newArgs.add(arg); - continue; + if (ptrType->getAddressSpace() == AddressSpace::UserPointer) + { + // If the parameter has an explicit pointer type, + // then we know the user is using the variable pointer + // capability to pass a true pointer. + // In this case we should not rewrite the call. + newArgs.add(arg); + continue; + } } auto ptrType = as(arg->getDataType()); if (!as(arg->getDataType())) @@ -953,13 +957,26 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // If we reach here, we need to allocate a temp var. auto tempVar = builder.emitVar(ptrType->getValueType()); + builder.addDecoration(tempVar, kIROp_DisableCopyEliminationDecoration); auto load = builder.emitLoad(arg); builder.emitStore(tempVar, load); newArgs.add(tempVar); + argsChanged = true; + + // We may need to write the value back to the original pointer argument + // after the call. + // + // If callee doesn't modify the memory location, no need to write back. + if (funcType && funcType->getParamCount() > i && + as(funcType->getParamType(i))) + continue; + // If the buffer location is immutable, don't write back. + if (isPointerToImmutableLocation(root)) + continue; writeBacks.add(WriteBackPair{arg, tempVar}); } SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount()); - if (writeBacks.getCount()) + if (argsChanged) { auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs); for (auto wb : writeBacks) @@ -2297,19 +2314,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // so we need to update the function types to match that. updateFunctionTypes(); - // Lower all loads/stores from buffer pointers to use correct storage types. - // We didn't do the lowering for buffer pointers because we don't know which pointer - // types are actual storage buffer pointers until we propagated the address space of - // pointers in this pass. In the future we should consider separate out IRAddress as - // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can - // safely lower the pointer load stores early together with other buffer types. - BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions; - bufferElementTypeLoweringOptions.lowerBufferPointer = true; - lowerBufferElementTypeToStorageType( - m_sharedContext->m_targetProgram, - m_module, - bufferElementTypeLoweringOptions); - // Look for structs that are both used as fields and marked with Block // decorations, and move the Block decoration to a wrapper struct. legalizeStructBlocks(); diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp index 9328a1de1..d34b3d25b 100644 --- a/source/slang/slang-ir-transform-params-to-constref.cpp +++ b/source/slang/slang-ir-transform-params-to-constref.cpp @@ -31,10 +31,7 @@ struct TransformParamsToConstRefContext case kIROp_StructType: case kIROp_ArrayType: case kIROp_UnsizedArrayType: - case kIROp_VectorType: - case kIROp_MatrixType: case kIROp_TupleType: - case kIROp_CoopVectorType: // valid type, continue to check break; default: @@ -44,58 +41,121 @@ struct TransformParamsToConstRefContext return true; } - void rewriteParamUseSitesToSupportConstRefUsage(HashSet& updatedParams) + void rewriteValueUsesToAddrUses(IRInst* newAddrInst) { - // Traverse the uses of our updated params to rewrite them. - // Assume a `in` parameter has been converted to a `constref` parameter. - for (auto param : updatedParams) + HashSet workListSet; + workListSet.add(newAddrInst); + List workList; + workList.add(newAddrInst); + auto _addToWorkList = [&](IRInst* inst) + { + if (workListSet.add(inst)) + workList.add(inst); + }; + for (Index i = 0; i < workList.getCount(); i++) { + auto inst = workList[i]; traverseUses( - param, + inst, [&](IRUse* use) { auto user = use->getUser(); + if (workListSet.contains(user)) + return; switch (user->getOp()) { case kIROp_FieldExtract: { // Transform the IRFieldExtract into a IRFieldAddress + if (!isUseBaseAddrOperand(use, user)) + break; auto fieldExtract = as(user); builder.setInsertBefore(fieldExtract); auto fieldAddr = builder.emitFieldAddress( fieldExtract->getBase(), fieldExtract->getField()); - auto loadInst = builder.emitLoad(fieldAddr); - fieldExtract->replaceUsesWith(loadInst); - fieldExtract->removeAndDeallocate(); - break; + fieldExtract->replaceUsesWith(fieldAddr); + _addToWorkList(fieldAddr); + return; } case kIROp_GetElement: { // Transform the IRGetElement into a IRGetElementPtr + if (!isUseBaseAddrOperand(use, user)) + break; auto getElement = as(user); builder.setInsertBefore(getElement); auto elemAddr = builder.emitElementAddress( getElement->getBase(), getElement->getIndex()); - auto loadInst = builder.emitLoad(elemAddr); - getElement->replaceUsesWith(loadInst); - getElement->removeAndDeallocate(); - break; + getElement->replaceUsesWith(elemAddr); + _addToWorkList(elemAddr); + return; } - default: + case kIROp_Store: { - // Insert a load before the user and replace the user with the load - builder.setInsertBefore(user); - auto loadInst = builder.emitLoad(param); - use->set(loadInst); + // If the current value is being stored into a write-once temp var that + // is immediately passed into a constref location in a call, we can get + // rid of the temp var and replace it with `inst` directly. + // (such temp var can be introduced during `updateCallSites` when we + // were processing the callee.) + // + auto dest = as(user)->getPtr(); + if (dest->findDecorationImpl(kIROp_TempCallArgImmutableVarDecoration)) + { + user->removeAndDeallocate(); + dest->replaceUsesWith(inst); + dest->removeAndDeallocate(); + return; + } break; } } + // Insert a load before the user and replace the user with the load + builder.setInsertBefore(user); + auto loadInst = builder.emitLoad(inst); + use->set(loadInst); }); } } + void rewriteParamUseSitesToSupportConstRefUsage(HashSet& updatedParams) + { + // Traverse the uses of our updated params to rewrite them. + // Assume a `in` parameter has been converted to a `constref` parameter. + for (auto param : updatedParams) + { + rewriteValueUsesToAddrUses(param); + } + } + + // Check if `load` is an `IRLoad(addr)` where `addr` is a immutable location. + IRInst* isLoadFromImmutableAddress(IRInst* load) + { + if (load->getOp() != kIROp_Load) + return nullptr; + auto addr = load->getOperand(0); + auto root = getRootAddr(addr); + if (!root) + return nullptr; + if (!root->getDataType()) + return nullptr; + switch (root->getDataType()->getOp()) + { + case kIROp_ConstantBufferType: + case kIROp_ConstRefType: + case kIROp_ParameterBlockType: + return addr; + default: + // Note that we should in general not assume a read-only StructuredBuffer or + // a pointer with read-only access as an immutable location due to potential aliasing. + // We could introduce a compiler flag to turn on optimizations on these buffer types + // assuming there is no aliasing. + break; + } + return nullptr; + } + // Update call sites to pass an address instead of value for each updated-param void updateCallSites(IRFunc* func, HashSet& updatedParams) { @@ -119,10 +179,19 @@ struct TransformParamsToConstRefContext newArgs.add(arg); continue; } - - auto tempVar = builder.emitVar(arg->getFullType()); - builder.emitStore(tempVar, arg); - newArgs.add(tempVar); + if (auto addr = isLoadFromImmutableAddress(arg)) + { + // If existing argument is a load from an immutable buffer address, + // we can pass in the address as is, without making a temporary copy. + newArgs.add(addr); + } + else + { + auto tempVar = builder.emitVar(arg->getFullType()); + builder.addDecoration(tempVar, kIROp_TempCallArgImmutableVarDecoration); + builder.emitStore(tempVar, arg); + newArgs.add(tempVar); + } } // Create new call with updated arguments @@ -177,6 +246,25 @@ struct TransformParamsToConstRefContext { HashSet updatedParams; + // If the function is used in any way that is not understood by the + // compiler, do not modify it. + // For example, if the function is used as callback, we must preserve + // its signature. + for (auto use = func->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (as(user)) + continue; + if (auto call = as(user)) + { + if (call->getCalleeUse() == use) + continue; + } + // If we reach here, we encountered a non-call use of the func, + // we will stop processing. + return; + } + // First pass: Transform parameter types for (auto param = func->getFirstParam(); param; param = param->getNextParam()) { @@ -192,7 +280,7 @@ struct TransformParamsToConstRefContext // This allows us to pass the address of variables directly into a function, // giving us the choice to remove copies into a parameter. auto paramType = param->getDataType(); - auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal); + auto constRefType = builder.getConstRefType(paramType, AddressSpace::Generic); param->setFullType(constRefType); changed = true; @@ -205,6 +293,8 @@ struct TransformParamsToConstRefContext return; } + fixUpFuncType(func); + // Second pass: Update function body according to the new `constref` parameters rewriteParamUseSitesToSupportConstRefUsage(updatedParams); diff --git a/source/slang/slang-ir-undo-param-copy.cpp b/source/slang/slang-ir-undo-param-copy.cpp index d8aac7201..75ee8a03a 100644 --- a/source/slang/slang-ir-undo-param-copy.cpp +++ b/source/slang/slang-ir-undo-param-copy.cpp @@ -6,7 +6,7 @@ namespace Slang { -// This pass transforms variables decorated with TempCallArgVarDecoration +// This pass transforms variables decorated with TempCallArgImmutableVarDecoration // by replacing them with direct references to the original parameters. // This is important for CUDA/OptiX targets where functions like 'IgnoreHit' // can prevent copy-back operations from executing. @@ -52,7 +52,17 @@ struct UndoParameterCopyVisitor { if (auto varInst = as(inst)) { - if (varInst->findDecoration()) + bool isTempCallArgVar = false; + for (auto decor : varInst->getDecorations()) + { + if (as(decor) || + as(decor)) + { + isTempCallArgVar = true; + break; + } + } + if (isTempCallArgVar) { IRStore* initializingStore = nullptr; IRInst* originalParamPtr = nullptr; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 9b852b803..8584ea95e 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2128,6 +2128,40 @@ IRType* getIRVectorBaseType(IRType* type) return as(type)->getElementType(); } +IRType* getElementType(IRBuilder& builder, IRType* valueType) +{ + valueType = (IRType*)unwrapAttributedType(valueType); + if (auto arrayType = as(valueType)) + { + return arrayType->getElementType(); + } + else if (auto vectorType = as(valueType)) + { + return vectorType->getElementType(); + } + else if (auto basicType = as(valueType)) + { + return basicType; + } + else if (auto coopVecType = as(valueType)) + { + return coopVecType->getElementType(); + } + else if (auto matrixType = as(valueType)) + { + return builder.getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); + } + else if (auto coopMatType = as(valueType)) + { + return coopMatType->getElementType(); + } + else if (auto hlslInputPatchType = as(valueType)) + { + return hlslInputPatchType->getElementType(); + } + return nullptr; +} + Int getSpecializationConstantId(IRGlobalParam* param) { auto layout = findVarLayout(param); @@ -2483,4 +2517,47 @@ bool isIROpaqueType(IRType* type) } } +bool isPointerToImmutableLocation(IRInst* loc) +{ + switch (loc->getOp()) + { + case kIROp_GetStructuredBufferPtr: + case kIROp_ImageSubscript: + return isPointerToImmutableLocation(loc->getOperand(0)); + default: + break; + } + + auto type = loc->getDataType(); + if (!type) + return false; + + switch (type->getOp()) + { + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLByteAddressBufferType: + case kIROp_ConstantBufferType: + case kIROp_ParameterBlockType: + return true; + default: + break; + } + + if (auto textureType = as(type)) + return textureType->getAccess() == SLANG_RESOURCE_ACCESS_READ; + + if (auto ptrType = as(type)) + { + switch (ptrType->getAddressSpace()) + { + case AddressSpace::BuiltinInput: + case AddressSpace::Input: + case AddressSpace::MetalObjectData: + case AddressSpace::Uniform: + case AddressSpace::UniformConstant: + return true; + } + } + return false; +} } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index b7aafea56..c0410fa3c 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -376,6 +376,10 @@ void verifyComputeDerivativeGroupModifiers( int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); +// Retrieves the element type of a pointer, buffer, array, vector or matrix type. +// This is the result type of a ElementExtract operation on a value of `type`. +IRType* getElementType(IRBuilder& builder, IRType* type); + Int getSpecializationConstantId(IRGlobalParam* param); void legalizeDefUse(IRGlobalValueWithCode* func); @@ -418,6 +422,22 @@ IRType* getUnsignedTypeFromSignedType(IRBuilder* builder, IRType* type); bool isSignedType(IRType* type); bool isIROpaqueType(IRType* type); + +// Returns true if the memory location pointed to by `ptrInst` is immutable. +// An immutable location is the memory region that can't be modified by the user code. +// Examples are ConstantBuffer and shader resource contents(e.g. StructuredBuffer). +// Note that this is to be disguished from the access qualifier of the pointer itself, +// e.g. a `ptrInst` of type `Ptr` may still point to a mutable location, +// so this function returns false in that case. +bool isPointerToImmutableLocation(IRInst* ptrInst); + +// Check if `use` is the `baseAddr` operand of a GetElement/FieldExtract inst. +// This is true if `use` is the first operand of the user inst. +inline bool isUseBaseAddrOperand(IRUse* use, IRInst* user) +{ + return user->getOperandUse(0) == use; +} + } // namespace Slang #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e9d1a1199..92543c952 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -386,6 +386,20 @@ IRType* tryGetPointedToType(IRBuilder* builder, IRType* type) return nullptr; } +IRType* tryGetPointedToOrBufferElementType(IRBuilder* builder, IRType* type) +{ + if (auto rateQualType = as(type)) + { + type = rateQualType->getValueType(); + } + auto resultType = tryGetPointedToType(builder, type); + if (resultType) + return resultType; + if (auto structuredBufferType = as(type)) + return structuredBufferType->getElementType(); + return nullptr; +} + // IRBlock @@ -5480,39 +5494,19 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) } IRType* type = nullptr; valueType = unwrapAttributedType(valueType); - if (auto arrayType = as(valueType)) - { - type = arrayType->getElementType(); - } - else if (auto vectorType = as(valueType)) - { - type = vectorType->getElementType(); - } - else if (auto coopVecType = as(valueType)) - { - type = coopVecType->getElementType(); - } - else if (auto matrixType = as(valueType)) - { - type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); - } - else if (auto coopMatType = as(valueType)) - { - type = coopMatType->getElementType(); - } - else if (const auto basicType = as(valueType)) + if (as(valueType)) { // HLSL support things like float.x, in which case we just return the base pointer. return basePtr; } - else if (const auto tupleType = as(valueType)) - { - SLANG_ASSERT(as(index)); - type = (IRType*)tupleType->getOperand(getIntVal(index)); - } - else if (auto hlslInputPatchType = as(valueType)) + type = getElementType(*this, (IRType*)valueType); + if (!type) { - type = hlslInputPatchType->getElementType(); + if (const auto tupleType = as(valueType)) + { + SLANG_ASSERT(as(index)); + type = (IRType*)tupleType->getOperand(getIntVal(index)); + } } SLANG_RELEASE_ASSERT(type); @@ -6179,6 +6173,24 @@ IRInst* IRBuilder::emitCastIntToPtr(IRType* ptrType, IRInst* val) return inst; } +IRInst* IRBuilder::emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType) +{ + if (type == val->getDataType()) + return val; + IRInst* args[] = {val, bufferType}; + return (IRCastStorageToLogical*)emitIntrinsicInst(type, kIROp_CastStorageToLogical, 2, args); +} + +IRCastStorageToLogicalDeref* IRBuilder::emitCastStorageToLogicalDeref( + IRType* type, + IRInst* val, + IRInst* bufferType) +{ + IRInst* args[] = {val, bufferType}; + return (IRCastStorageToLogicalDeref*) + emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args); +} + IRGlobalConstant* IRBuilder::emitGlobalConstant(IRType* type) { auto inst = createInst(this, kIROp_GlobalConstant, type); @@ -6627,6 +6639,7 @@ IRInst* IRBuilder::emitGenericAsm(UnownedStringSlice asmText) IRInst* IRBuilder::emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index) { const auto sbt = cast(structuredBuffer->getDataType()); + SLANG_ASSERT(sbt); const auto t = getPtrType(sbt->getElementType()); IRInst* const operands[2] = {structuredBuffer, index}; const auto i = createInst( diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 69a000b81..161e70b25 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -695,6 +695,12 @@ struct IRInst IRUse* getOperands(); + IRUse* getOperandUse(UInt index) + { + SLANG_ASSERT(index < getOperandCount()); + return getOperands() + index; + } + IRInst* getOperand(UInt index) { SLANG_ASSERT(index < getOperandCount()); @@ -1533,7 +1539,7 @@ struct IRUniformParameterGroupType : IRParameterGroupType FIDDLE() -struct IRGLSLShaderStorageBufferType : IRBuiltinGenericType +struct IRGLSLShaderStorageBufferType : IRPointerLikeType { FIDDLE(leafInst()) IRType* getDataLayout() { return (IRType*)getOperand(1); } @@ -1760,6 +1766,8 @@ struct IRGetStringHash : IRInst /// The given IR `builder` will be used if new instructions need to be created. IRType* tryGetPointedToType(IRBuilder* builder, IRType* type); +IRType* tryGetPointedToOrBufferElementType(IRBuilder* builder, IRType* type); + FIDDLE() struct IRFuncType : IRType { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ca32b0f2d..5d83a351f 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2062,6 +2062,21 @@ struct ValLoweringVisitor : ValVisitorgetOp() == kIROp_IntLit) + { + auto constVal = as(inst); + return context->irBuilder->getIntValue( + context->irBuilder->getUInt64Type(), + constVal->value.intVal); + } + else + { + return context->irBuilder->emitCast(context->irBuilder->getUInt64Type(), inst); + } + } + IRType* visitPtrType(PtrType* type) { auto astValueType = type->getValueType(); @@ -2072,12 +2087,14 @@ struct ValLoweringVisitor : ValVisitorgetAccessQualifier()) { - accessQualifier = getSimpleVal(context, lowerVal(context, astAccessQualifier)); + accessQualifier = + convertToUInt64Value(getSimpleVal(context, lowerVal(context, astAccessQualifier))); } if (auto astAddrSpace = type->getAddressSpace()) { - addrSpace = getSimpleVal(context, lowerVal(context, astAddrSpace)); + addrSpace = + convertToUInt64Value(getSimpleVal(context, lowerVal(context, astAddrSpace))); } else { diff --git a/tests/compute/byte-address-buffer-array.slang b/tests/compute/byte-address-buffer-array.slang index 90cdb2261..58862d1ac 100644 --- a/tests/compute/byte-address-buffer-array.slang +++ b/tests/compute/byte-address-buffer-array.slang @@ -1,7 +1,7 @@ // byte-address-buffer-array.slang //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -d3d12 -profile cs_6_0 -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -cuda -profile cs_6_0 -shaderobj -output-using-type -//DISABLED_TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -output-using-type //TEST:SIMPLE(filecheck=CHECK2):-target hlsl -entry computeMain -stage compute //TEST:SIMPLE(filecheck=CHECK3):-target spirv -entry computeMain -stage compute @@ -10,7 +10,7 @@ // Confirm compilation of `(RW)ByteAddressBuffer` with aligned load / stores to wider data types. //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=buffer -[vk::binding(2, 3)] RWByteAddressBuffer buffer; +RWByteAddressBuffer buffer; struct Block { float4 val[2]; }; diff --git a/tests/optimization/arrray-storage-lowering.slang b/tests/optimization/arrray-storage-lowering.slang new file mode 100644 index 000000000..42bb8f127 --- /dev/null +++ b/tests/optimization/arrray-storage-lowering.slang @@ -0,0 +1,42 @@ +// TEST:SIMPLE(filecheck=SPV): -target spirv + +// TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type -emit-spirv-directly + +struct DoubleNested +{ + int4x3 matrix; + int getMatVal(int i, int j) { return matrix[i][j]; } +} + +struct Nested +{ + bool values[4]; + DoubleNested doubleNested; + int getVal(int id) { return (int)values[0] + doubleNested.getMatVal(0, 1); } +} + +struct Params +{ + Nested nested; + + int getVal(int id) { return nested.getVal(id) + nested.getVal(id + 1); } +} + +// TEST_INPUT: set outputBuffer = out ubuffer(data=[0], stride=4) +RWStructuredBuffer outputBuffer; + +// TEST_INPUT:set gParams = cbuffer(data=[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]) +ConstantBuffer gParams; + +// TEST_INPUT: set gDoubleNested = ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12]) +uniform DoubleNested *gDoubleNested; + +// CHECK: 9 + +[numthreads(1,1,1)] +void computeMain(int id: SV_DispatchThreadID) +{ + outputBuffer[0].xyz = gParams.getVal(id) + gDoubleNested.getMatVal(1, 1); +} + +// SPV-NOT: OpCompositeConstruct diff --git a/tests/optimization/get-array-element.slang b/tests/optimization/get-array-element.slang new file mode 100644 index 000000000..16a71aee2 --- /dev/null +++ b/tests/optimization/get-array-element.slang @@ -0,0 +1,17 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +int test(int arr[32]) { + int sum = 0; + for (int i =0; i < 32; i++) sum += arr[i]; + return sum; +} + +uniform int gArr[32]; +uniform int* result; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK-NOT: OpCompositeConstruct + *result = test(gArr); +} \ No newline at end of file diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang b/tests/pipeline/rasterization/get-attribute-at-vertex.slang index 342796b90..f73ebe86f 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang @@ -9,8 +9,11 @@ // CHECK-SPIRV: OpExtension "SPV_KHR_fragment_shader_barycentric" // CHECK-SPIRV: OpEntryPoint Fragment %main "main" -// CHECK-SPIRV: OpDecorate %{{.*}} BuiltIn BaryCoordKHR -// CHECK-SPIRV: OpDecorate %{{.*}} BuiltIn BaryCoordNoPerspKHR + +// CHECK-SPIRV-DAG: OpDecorate %{{.*}} BuiltIn BaryCoordKHR + +// CHECK-SPIRV-DAG: OpDecorate %{{.*}} BuiltIn BaryCoordNoPerspKHR + // CHECK-SPIRV: %{{.*}} = OpAccessChain %_ptr_Input_{{.*}} %{{.*}} %uint_0 // CHECK-SPIRV: %{{.*}} = OpAccessChain %_ptr_Input_{{.*}} %{{.*}} %uint_1 // CHECK-SPIRV: %{{.*}} = OpAccessChain %_ptr_Input_{{.*}} %{{.*}} %uint_2 diff --git a/tests/spirv/aligned-load-store.slang b/tests/spirv/aligned-load-store.slang index c2f50b66c..8131e1cc7 100644 --- a/tests/spirv/aligned-load-store.slang +++ b/tests/spirv/aligned-load-store.slang @@ -4,8 +4,6 @@ // CHECK: OpStore {{.*}} Aligned 16 // CHECK: OpLoad {{.*}} Aligned 16 -// CHECK: OpLoad {{.*}} Aligned 16 -// CHECK: OpStore {{.*}} Aligned 16 // CHECK: OpStore {{.*}} Aligned 16 uniform float4* data; diff --git a/tests/spirv/buffer-pointer-matrix-layout.slang b/tests/spirv/buffer-pointer-matrix-layout.slang index cbb8f2857..4d80419e4 100644 --- a/tests/spirv/buffer-pointer-matrix-layout.slang +++ b/tests/spirv/buffer-pointer-matrix-layout.slang @@ -1,33 +1,50 @@ -//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -stage compute -entry main -matrix-layout-column-major +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -output-using-type -xslang -matrix-layout-column-major -emit-spirv-directly -// CHECK: OpLoad {{.*}} Aligned 4 +// TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0 12.0],stride=4) +uniform float3x4 *ptr; -struct Push -{ - float3x4* ptr; -}; +// TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0],stride=4) +RWStructuredBuffer outputBuffer; -[[vk::push_constant]] Push push; [shader("compute")] [numthreads(1, 1, 1)] -void main(uint3 dtid : SV_DispatchThreadID) -{ +void computeMain(uint3 dtid: SV_DispatchThreadID) +{ // This matrix is in memry column major. Slang respects this here and load it properly! - float3x4 correctly_read_matrix = *push.ptr; - printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n", - correctly_read_matrix[0][0], correctly_read_matrix[0][1], correctly_read_matrix[0][2], correctly_read_matrix[0][3], - correctly_read_matrix[1][0], correctly_read_matrix[1][1], correctly_read_matrix[1][2], correctly_read_matrix[1][3] - ); - printf("(%f,%f,%f,%f)\n\n", - correctly_read_matrix[2][0], correctly_read_matrix[2][1], correctly_read_matrix[2][2], correctly_read_matrix[2][3] - ); - // With this syntax however, Slang ignores the column major setting and loads it as it it was row major! - float3x4 broken_matrix = push.ptr[0]; - printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n", - broken_matrix[0][0], broken_matrix[0][1], broken_matrix[0][2], broken_matrix[0][3], - broken_matrix[1][0], broken_matrix[1][1], broken_matrix[1][2], broken_matrix[1][3] - ); - printf("(%f,%f,%f,%f)\n\n", - broken_matrix[2][0], broken_matrix[2][1], broken_matrix[2][2], broken_matrix[2][3] - ); + float3x4 correctly_read_matrix = *ptr; + outputBuffer[0] = correctly_read_matrix[0][0]; + outputBuffer[1] = correctly_read_matrix[0][1]; + outputBuffer[2] = correctly_read_matrix[0][2]; + outputBuffer[3] = correctly_read_matrix[0][3]; + outputBuffer[4] = correctly_read_matrix[1][0]; + outputBuffer[5] = correctly_read_matrix[1][1]; + outputBuffer[6] = correctly_read_matrix[1][2]; + outputBuffer[7] = correctly_read_matrix[1][3]; + // CHECK: 1.0 + // CHECK: 4.0 + // CHECK: 7.0 + // CHECK: 10.0 + // CHECK: 2.0 + // CHECK: 5.0 + // CHECK: 8.0 + // CHECK: 11.0 + + // With this syntax however, Slang was ignoring the column major setting and loads it as it it was row major! + float3x4 broken_matrix = ptr[0]; + outputBuffer[8] = broken_matrix[0][0]; + outputBuffer[9] = broken_matrix[0][1]; + outputBuffer[10] = broken_matrix[0][2]; + outputBuffer[11] = broken_matrix[0][3]; + outputBuffer[12] = broken_matrix[1][0]; + outputBuffer[13] = broken_matrix[1][1]; + outputBuffer[14] = broken_matrix[1][2]; + outputBuffer[15] = broken_matrix[1][3]; + // CHECK: 1.0 + // CHECK: 4.0 + // CHECK: 7.0 + // CHECK: 10.0 + // CHECK: 2.0 + // CHECK: 5.0 + // CHECK: 8.0 + // CHECK: 11.0 } \ No newline at end of file diff --git a/tests/spirv/geometry-shader-sub-func.slang b/tests/spirv/geometry-shader-sub-func.slang index 6c6944f31..20634ea67 100644 --- a/tests/spirv/geometry-shader-sub-func.slang +++ b/tests/spirv/geometry-shader-sub-func.slang @@ -36,7 +36,7 @@ void main( { CoarseVertex coarseVertex = coarseVertices[ii]; RasterVertex rasterVertex; - rasterVertex.position = coarseVertex.position; + rasterVertex.position = coarseVertex.position; rasterVertex.color = coarseVertex.color; rasterVertex.id = coarseVertex.id + primitiveID; appendVertex(outputStream, rasterVertex); diff --git a/tests/spirv/large-struct.slang b/tests/spirv/large-struct.slang index 2d79c0aaf..e4cbd6d1c 100644 --- a/tests/spirv/large-struct.slang +++ b/tests/spirv/large-struct.slang @@ -1,5 +1,5 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460 -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type -xslang -g0 //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-d3d12 -compute -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -output-using-type diff --git a/tests/spirv/pointer-2.slang b/tests/spirv/pointer-2.slang index 1f2b2d0ea..b93ca32b4 100644 --- a/tests/spirv/pointer-2.slang +++ b/tests/spirv/pointer-2.slang @@ -1,6 +1,6 @@ -//TEST:SIMPLE(filecheck=CHECK_SPV): -entry vertexMain -stage vertex -emit-spirv-directly -target spirv -//TEST:SIMPLE(filecheck=CHECK_SPV_VIA_GLSL): -entry vertexMain -stage vertex -emit-spirv-via-glsl -target spirv -//TEST:SIMPLE(filecheck=CHECK_GLSL): -entry vertexMain -stage vertex -target glsl +// TEST:SIMPLE(filecheck=CHECK_GLSL): -entry vertexMain -stage vertex -target glsl +// TEST:SIMPLE(filecheck=CHECK_SPV): -entry vertexMain -stage vertex -emit-spirv-directly -target spirv +// TEST:SIMPLE(filecheck=CHECK_SPV_VIA_GLSL): -entry vertexMain -stage vertex -emit-spirv-via-glsl -target spirv struct Inner1 { diff --git a/tests/spirv/spec-constant-operations.slang b/tests/spirv/spec-constant-operations.slang index 86d16ef34..7c7b9bc60 100644 --- a/tests/spirv/spec-constant-operations.slang +++ b/tests/spirv/spec-constant-operations.slang @@ -11,8 +11,6 @@ RWStructuredBuffer outputBuffer; // CHECK-DAG: OpSpecConstant %float 1 // CHECK-DAG: OpSpecConstant %ulong 256 // CHECK-DAG: OpSpecConstant %float 100 -// CHECK-DAG: OpSpecConstantOp %half FConvert -// CHECK-DAG: OpSpecConstantOp %int UConvert // CHECK-NOT: OpSpecConstantOp {{.*}} FAdd // CHECK-NOT: OpSpecConstantOp {{.*}} FSub diff --git a/tests/spirv/spirv-debug-break.slang b/tests/spirv/spirv-debug-break.slang index e57024037..67b3e975b 100644 --- a/tests/spirv/spirv-debug-break.slang +++ b/tests/spirv/spirv-debug-break.slang @@ -1,5 +1,5 @@ // spirv-instruction.slang -//TEST(compute, vulkan):SIMPLE(filecheck=CHECK):-target glsl -entry computeMain -stage compute +// TEST:SIMPLE(filecheck=CHECK):-target glsl -entry computeMain -stage compute [[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]] void _spvDebugBreak(int v); diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index bb213cdaa..c03afaa9e 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -1015,9 +1015,12 @@ struct ShaderInputLayoutParser for (auto& line : lines) { lineNum++; - if (line.startsWith("//TEST_INPUT:")) + if (!line.startsWith("//")) + continue; + line = line.getUnownedSlice().tail(2).trim(); + if (line.startsWith("TEST_INPUT:")) { - auto lineContent = line.subString(13, line.getLength() - 13); + auto lineContent = line.subString(11, line.getLength() - 11); Misc::TokenReader parser(lineContent); try { -- cgit v1.2.3