diff options
| author | Yong He <yonghe@outlook.com> | 2022-02-25 20:49:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-02-25 20:49:31 -0800 |
| commit | c31577953d5041c82375c22d847c2eba06106c58 (patch) | |
| tree | bc685a8b63fc13cb85d160ae13df950056ca6e91 | |
| parent | 8990d270e3a0c01b1f7abbf4f79556c5ef82a096 (diff) | |
Improved SCCP, inlining and resource specialization passes, legalize `ImageSubscript` for GLSL (#2146)
44 files changed, 1927 insertions, 257 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 4c6bcbea2..c2387346d 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -377,6 +377,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClInclude Include="..\..\..\source\slang\slang-ir-restructure-scoping.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-sccp.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-simplify-cfg.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-arrays.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-buffer-load-arg.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-dispatch.h" />
@@ -386,6 +387,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-ssa-simplification.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-string-hash.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-strip-witness-tables.h" />
@@ -508,6 +510,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClCompile Include="..\..\..\source\slang\slang-ir-restructure-scoping.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-sccp.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-simplify-cfg.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-arrays.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-buffer-load-arg.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-dispatch.cpp" />
@@ -517,6 +520,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-ssa-simplification.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-string-hash.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-strip-witness-tables.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index caea98ea9..d86cdcbea 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -228,6 +228,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-sccp.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-simplify-cfg.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-arrays.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -255,6 +258,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-ssa-simplification.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -617,6 +623,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-sccp.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-simplify-cfg.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-arrays.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -644,6 +653,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-ssa-simplification.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 2e1ab33f2..d3d77b804 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -64,6 +64,11 @@ __target_intrinsic(hlsl, "NvInterlockedAddUint64($0, $1, $2)") [__requiresNVAPI] uint2 __atomicAdd(RWByteAddressBuffer buf, uint offset, uint2); +// atomic add for hlsl using SM6.6 +__target_intrinsic(hlsl, "$0.InterlockedAdd64($1, $2, $3)") +void __atomicAdd(RWByteAddressBuffer buf, uint offset, int64_t value, out int64_t originalValue); +__target_intrinsic(hlsl, "$0.InterlockedAdd64($1, $2, $3)") +void __atomicAdd(RWByteAddressBuffer buf, uint offset, uint64_t value, out uint64_t originalValue); // Int versions require glsl 4.30 // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/atomicAdd.xhtml @@ -81,6 +86,11 @@ __glsl_version(430) __glsl_extension(GL_EXT_shader_atomic_int64) int64_t __atomicAdd(__ref int64_t value, int64_t amount); +__target_intrinsic(glsl, "atomicAdd($0, $1)") +__glsl_version(430) +__glsl_extension(GL_EXT_shader_atomic_int64) +uint64_t __atomicAdd(__ref uint64_t value, uint64_t amount); + // Cas - Compare and swap // Helper for HLSL, using NVAPI @@ -89,6 +99,17 @@ __target_intrinsic(hlsl, "NvInterlockedCompareExchangeUint64($0, $1, $2, $3)") [__requiresNVAPI] uint2 __cas(RWByteAddressBuffer buf, uint offset, uint2 compareValue, uint2 value); +// CAS using SM6.6 +__target_intrinsic(hlsl, "$0.InterlockedCompareExchange64($1, $2, $3, $4)") +void __cas(RWByteAddressBuffer buf, uint offset, in int64_t compare_value, in int64_t value, out int64_t original_value); +__target_intrinsic(hlsl, "$0.InterlockedCompareExchange64($1, $2, $3, $4)") +void __cas(RWByteAddressBuffer buf, uint offset, in uint64_t compare_value, in uint64_t value, out uint64_t original_value); + +__target_intrinsic(glsl, "atomicCompSwap($0, $1, $2)") +__glsl_version(430) +__glsl_extension(GL_EXT_shader_atomic_int64) +uint64_t __cas(__ref int64_t ioValue, int64_t compareValue, int64_t newValue); + __target_intrinsic(glsl, "atomicCompSwap($0, $1, $2)") __glsl_version(430) __glsl_extension(GL_EXT_shader_atomic_int64) @@ -482,6 +503,51 @@ ${{{{ return __atomicExchange(buf[byteAddress / 8], value); } + // SM6.6 6 64bit atomics. + __specialized_for_target(hlsl) + void InterlockedAdd64(uint byteAddress, int64_t valueToAdd, out int64_t outOriginalValue) + { + __atomicAdd(this, byteAddress, valueToAdd, outOriginalValue); + } + __specialized_for_target(glsl) + void InterlockedAdd64(uint byteAddress, int64_t valueToAdd, out int64_t originalValue) + { + RWStructuredBuffer<int64_t> buf = __getEquivalentStructuredBuffer<int64_t>(this); + originalValue = __atomicAdd(buf[byteAddress / 8], valueToAdd); + } + __specialized_for_target(hlsl) + void InterlockedAdd64(uint byteAddress, uint64_t valueToAdd, out uint64_t outOriginalValue) + { + __atomicAdd(this, byteAddress, valueToAdd, outOriginalValue); + } + __specialized_for_target(glsl) + void InterlockedAdd64(uint byteAddress, uint64_t valueToAdd, out uint64_t originalValue) + { + RWStructuredBuffer<uint64_t> buf = __getEquivalentStructuredBuffer<uint64_t>(this); + originalValue = __atomicAdd(buf[byteAddress / 8], valueToAdd); + } + __specialized_for_target(hlsl) + void InterlockedCompareExchange64(uint byteAddress, int64_t compareValue, int64_t value, out int64_t outOriginalValue) + { + __cas(this, byteAddress, compareValue, value, outOriginalValue); + } + __specialized_for_target(glsl) + void InterlockedCompareExchange64(uint byteAddress, int64_t compareValue, int64_t value, out int64_t outOriginalValue) + { + RWStructuredBuffer<int64_t> buf = __getEquivalentStructuredBuffer<int64_t>(this); + outOriginalValue = __cas(buf[byteAddress / 8], compareValue, value); + } + __specialized_for_target(hlsl) + void InterlockedCompareExchange64(uint byteAddress, uint64_t compareValue, uint64_t value, out uint64_t outOriginalValue) + { + __cas(this, byteAddress, compareValue, value, outOriginalValue); + } + __specialized_for_target(glsl) + void InterlockedCompareExchange64(uint byteAddress, uint64_t compareValue, uint64_t value, out uint64_t outOriginalValue) + { + RWStructuredBuffer<uint64_t> buf = __getEquivalentStructuredBuffer<uint64_t>(this); + outOriginalValue = __cas(buf[byteAddress / 8], compareValue, value); + } ${{{{ } }}}} diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 03943abb3..6b77cba6a 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -893,10 +893,20 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst) return; } case BaseType::UInt8: + { + m_writer->emit(UInt(uint8_t(litInst->value.intVal))); + m_writer->emit("U"); + break; + } case BaseType::UInt16: + { + m_writer->emit(UInt(uint16_t(litInst->value.intVal))); + m_writer->emit("U"); + break; + } case BaseType::UInt: { - m_writer->emit(UInt(litInst->value.intVal)); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); m_writer->emit("U"); break; } @@ -1010,6 +1020,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) // case kIROp_makeStruct: case kIROp_makeArray: + case kIROp_swizzleSet: return false; } diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 27eb75d34..f7850179d 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -895,6 +895,24 @@ void CUDASourceEmitter::emitSimpleFuncImpl(IRFunc* func) CLikeSourceEmitter::emitSimpleFuncImpl(func); } +void CUDASourceEmitter::emitSimpleValueImpl(IRInst* inst) +{ + // Make sure we convert float to half when emitting a half literal to avoid + // overload ambiguity errors from CUDA. + if (inst->getOp() == kIROp_FloatLit) + { + if (inst->getDataType()->getOp() == kIROp_HalfType) + { + m_writer->emit("__half("); + CLikeSourceEmitter::emitSimpleValueImpl(inst); + m_writer->emit(")"); + return; + } + } + CLikeSourceEmitter::emitSimpleValueImpl(inst); +} + + void CUDASourceEmitter::emitSemanticsImpl(IRInst* inst) { Super::emitSemanticsImpl(inst); diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h index 07f429898..ad2f8ed0a 100644 --- a/source/slang/slang-emit-cuda.h +++ b/source/slang/slang-emit-cuda.h @@ -72,6 +72,7 @@ protected: virtual void emitSimpleFuncParamsImpl(IRFunc* func) SLANG_OVERRIDE; virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; + virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 441fc94d7..44326a18e 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -774,38 +774,38 @@ void GLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) { emitType(type); m_writer->emit("("); - m_writer->emit(litInst->value.intVal); + m_writer->emit(int8_t(litInst->value.intVal)); m_writer->emit(")"); return; } case BaseType::Int16: { - m_writer->emit(litInst->value.intVal); + m_writer->emit(int16_t(litInst->value.intVal)); m_writer->emit("S"); return; } case BaseType::Int: { - m_writer->emit(litInst->value.intVal); + m_writer->emit(int32_t(litInst->value.intVal)); return; } case BaseType::UInt8: { emitType(type); m_writer->emit("("); - m_writer->emit(UInt(litInst->value.intVal)); + m_writer->emit(UInt(uint8_t(litInst->value.intVal))); m_writer->emit("U)"); return; } case BaseType::UInt16: { - m_writer->emit(UInt(litInst->value.intVal)); + m_writer->emit(UInt(uint16_t(litInst->value.intVal))); m_writer->emit("US"); return; } case BaseType::UInt: { - m_writer->emit(UInt(litInst->value.intVal)); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); m_writer->emit("U"); return; } @@ -1636,6 +1636,26 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu emitOperand(inst->getOperand(0), outerPrec); return true; } + case kIROp_ImageLoad: + { + m_writer->emit("imageLoad("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_ImageStore: + { + m_writer->emit("imageStore("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } case kIROp_StructuredBufferLoad: { auto outerPrec = inOuterPrec; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 378732fb3..a0ac30857 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -26,11 +26,13 @@ #include "slang-ir-optix-entry-point-uniforms.h" #include "slang-ir-restructure.h" #include "slang-ir-restructure-scoping.h" +#include "slang-ir-sccp.h" #include "slang-ir-specialize.h" #include "slang-ir-specialize-arrays.h" #include "slang-ir-specialize-buffer-load-arg.h" #include "slang-ir-specialize-resources.h" #include "slang-ir-ssa.h" +#include "slang-ir-ssa-simplification.h" #include "slang-ir-strip-witness-tables.h" #include "slang-ir-synthesize-active-mask.h" #include "slang-ir-union.h" @@ -324,10 +326,13 @@ Result linkAndOptimizeIR( specializeModule(irModule); dumpIRIfEnabled(compileRequest, irModule, "AFTER-SPECIALIZE"); + applySparseConditionalConstantPropagation(irModule); eliminateDeadCode(irModule); lowerReinterpret(targetRequest, irModule, sink); + validateIRModuleIfEnabled(compileRequest, irModule); + // For targets that supports dynamic dispatch, we need to lower the // generics / interface types to ordinary functions and types using // function pointers. @@ -359,10 +364,7 @@ Result linkAndOptimizeIR( // up downstream passes like type legalization, so we // will run a DCE pass to clean up after the specialization. // - // TODO: Are there other cleanup optimizations we should - // apply at this point? - // - eliminateDeadCode(irModule); + simplifyIR(irModule); #if 0 dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE"); #endif @@ -435,7 +437,7 @@ Result linkAndOptimizeIR( // to see if we can clean up any temporaries created by legalization. // (e.g., things that used to be aggregated might now be split up, // so that we can work with the individual fields). - constructSSA(irModule); + simplifyIR(irModule); #if 0 dumpIRIfEnabled(compileRequest, irModule, "AFTER SSA"); @@ -450,36 +452,12 @@ Result linkAndOptimizeIR( // Many of our targets place restrictions on how certain // resource types can be used, so that having them as // function parameters, reults, etc. is invalid. - // To clean this up, we apply two kinds of specialization: - // - // * Specalize call sites based on the actual resources - // that a called function will return/output. - // - // * Specialize called functions based on teh actual resources - // passed ass input at specific call sites. - // - // Because the legalization may depend on what target - // we are compiling for (certain things might be okay - // for D3D targets that are not okay for Vulkan), we - // pass down the target request along with the IR. - // - specializeResourceOutputs(compileRequest, targetRequest, irModule); - // - // After specialization of function outputs, we may find that there - // are cases where opaque-typed local variables can now be eliminated - // and turned into SSA temporaries. Such optimization may enable - // the following passes to "see" and specialize more cases. - // - // TODO: We should consider whether there are cases that will require - // iterating the passes as given here in order to achieve a fully - // specialized result. If that is the case, we might consider implementing - // a single combined pass that makes all of the relevant changes and - // iterates to convergence. - // - constructSSA(irModule); - // + // We clean up the usages of resource values here. + specializeResourceUsage(compileRequest, targetRequest, irModule); specializeFuncsForBufferLoadArgs(compileRequest, targetRequest, irModule); - specializeResourceParameters(compileRequest, targetRequest, irModule); + + // + simplifyIR(irModule); // For GLSL targets, we also want to specialize calls to functions that // takes array parameters if possible, to avoid performance issues on @@ -487,6 +465,7 @@ Result linkAndOptimizeIR( if (isKhronosTarget(targetRequest)) { specializeArrayParameters(compileRequest, targetRequest, irModule); + simplifyIR(irModule); } #if 0 @@ -675,6 +654,17 @@ Result linkAndOptimizeIR( break; } + // Legalize `ImageSubscript` for GLSL. + switch (target) + { + case CodeGenTarget::GLSL: + { + legalizeImageSubscriptForGLSL(irModule); + } + break; + default: + break; + } switch( target ) { @@ -712,11 +702,16 @@ Result linkAndOptimizeIR( // functions, so there might still be invalid code in // our IR module. // - // To clean up the code, we will apply a fairly general - // dead-code-elimination (DCE) pass that only retains - // whatever code is "live." + // We run IR simplification passes again to clean things up. // - eliminateDeadCode(irModule); + simplifyIR(irModule); + + if (isKhronosTarget(targetRequest)) + { + // As a fallback, if the above specialization steps failed to remove resource type parameters, we will + // inline the functions in question to make sure we can produce valid GLSL. + performGLSLResourceReturnFunctionInlining(irModule); + } #if 0 dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE"); #endif @@ -725,8 +720,7 @@ Result linkAndOptimizeIR( // Lower all bit_cast operations on complex types into leaf-level // bit_cast on basic types. lowerBitCast(targetRequest, irModule); - eliminateDeadCode(irModule); - + simplifyIR(irModule); // We include one final step to (optionally) dump the IR and validate // it after all of the optimization passes are complete. This should diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 285e5100c..9ed5249fe 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -81,7 +81,7 @@ struct DeadCodeEliminationContext // dive into the task of actually finding all // the live code in a module. // - void processModule() + bool processModule() { // First of all, we know that the root module instruction // should be considered as live, because otherwise @@ -174,11 +174,12 @@ struct DeadCodeEliminationContext // recursively and eliminate those that are "dead" by // virtue of not having been found live. // - eliminateDeadInstsRec(module->getModuleInst()); + return eliminateDeadInstsRec(module->getModuleInst()); } - void eliminateDeadInstsRec(IRInst* inst) + bool eliminateDeadInstsRec(IRInst* inst) { + bool changed = false; // Given the instruction `inst` we need to eliminate // any dead code at, or under it. // @@ -192,6 +193,7 @@ struct DeadCodeEliminationContext // mark the parent of a live instruction as live). // inst->removeAndDeallocate(); + changed = true; } else { @@ -208,9 +210,10 @@ struct DeadCodeEliminationContext for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next ) { next = child->getNextInst(); - eliminateDeadInstsRec(child); + changed |= eliminateDeadInstsRec(child); } } + return changed; } // Now we come to the decision procedure we put off before: @@ -336,7 +339,7 @@ struct DeadCodeEliminationContext // is straighforward. We set up the context object // and then defer to it for the real work. // -void eliminateDeadCode( +bool eliminateDeadCode( IRModule* module, IRDeadCodeEliminationOptions const& options) { @@ -344,7 +347,7 @@ void eliminateDeadCode( context.module = module; context.options = options; - context.processModule(); + return context.processModule(); } } diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h index 007905486..aac10a20a 100644 --- a/source/slang/slang-ir-dce.h +++ b/source/slang/slang-ir-dce.h @@ -17,8 +17,8 @@ namespace Slang /// "global" dead code elimination (DCE), such as removing /// types that are unused, functions that are never called, /// etc. - /// - void eliminateDeadCode( + /// Returns true if changed. + bool eliminateDeadCode( IRModule* module, IRDeadCodeEliminationOptions const& options = IRDeadCodeEliminationOptions()); } diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index bb40ecca9..83b75033a 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -8,8 +8,141 @@ namespace Slang { +int getIRVectorElementSize(IRType* type) +{ + if (type->getOp() != kIROp_VectorType) + return 1; + return (int)(as<IRIntLit>(as<IRVectorType>(type)->getElementCount())->value.intVal); +} +IRType* getIRVectorBaseType(IRType* type) +{ + if (type->getOp() != kIROp_VectorType) + return type; + return as<IRVectorType>(type)->getElementType(); +} -// +void legalizeImageSubscriptStoreForGLSL(IRBuilder& builder, IRInst* storeInst) +{ + builder.setInsertBefore(storeInst); + auto imageSubscript = as<IRImageSubscript>(storeInst->getOperand(0)); + assert(imageSubscript); + auto imageElementType = cast<IRPtrTypeBase>(imageSubscript->getDataType())->getValueType(); + auto coordType = imageSubscript->getCoord()->getDataType(); + auto coordVectorSize = getIRVectorElementSize(coordType); + if (coordVectorSize != 1) + { + coordType = builder.getVectorType( + builder.getIntType(), builder.getIntValue(builder.getIntType(), coordVectorSize)); + } + else + { + coordType = builder.getIntType(); + } + auto legalizedCoord = imageSubscript->getCoord(); + if (coordType != imageSubscript->getCoord()->getDataType()) + { + legalizedCoord = builder.emitConstructorInst(coordType, 1, &legalizedCoord); + } + switch (storeInst->getOp()) + { + case kIROp_Store: + { + auto newValue = storeInst->getOperand(1); + if (getIRVectorElementSize(imageElementType) != 4) + { + auto vectorBaseType = getIRVectorBaseType(imageElementType); + newValue = builder.emitConstructorInst( + builder.getVectorType( + vectorBaseType, builder.getIntValue(builder.getIntType(), 4)), + 1, + &newValue); + } + auto imageStore = builder.emitImageStore( + builder.getVoidType(), + imageSubscript->getImage(), + legalizedCoord, + newValue); + storeInst->replaceUsesWith(imageStore); + storeInst->removeAndDeallocate(); + if (!imageSubscript->hasUses()) + { + imageSubscript->removeAndDeallocate(); + } + } + break; + case kIROp_SwizzledStore: + { + auto swizzledStore = cast<IRSwizzledStore>(storeInst); + // Here we assume the imageElementType is already lowered into float4/uint4 types from any + // user-defined type. + assert(imageElementType->getOp() == kIROp_VectorType); + auto originalValue = builder.emitImageLoad(imageElementType, imageSubscript->getImage(), legalizedCoord); + Array<IRInst*, 4> indices; + for (UInt i = 0; i < swizzledStore->getElementCount(); i++) + { + indices.add(swizzledStore->getElementIndex(i)); + } + auto newValue = builder.emitSwizzleSet( + imageElementType, + originalValue, + swizzledStore->getSource(), + swizzledStore->getElementCount(), + indices.getBuffer()); + if (getIRVectorElementSize(imageElementType) != 4) + { + auto vectorBaseType = getIRVectorBaseType(imageElementType); + newValue = builder.emitConstructorInst( + builder.getVectorType( + vectorBaseType, builder.getIntValue(builder.getIntType(), 4)), + 1, + &newValue); + } + auto imageStore = builder.emitImageStore( + builder.getVoidType(), imageSubscript->getImage(), legalizedCoord, newValue); + storeInst->replaceUsesWith(imageStore); + storeInst->removeAndDeallocate(); + if (!imageSubscript->hasUses()) + { + imageSubscript->removeAndDeallocate(); + } + } + break; + default: + break; + } +} + +void legalizeImageSubscriptForGLSL(IRModule* module) +{ + SharedIRBuilder shared(module); + IRBuilder builder(shared); + for (auto globalInst : module->getModuleInst()->getChildren()) + { + auto func = as<IRFunc>(globalInst); + if (!func) + continue; + for (auto block : func->getBlocks()) + { + auto inst = block->getFirstInst(); + IRInst* next; + for ( ; inst; inst = next) + { + next = inst->getNextInst(); + switch (inst->getOp()) + { + case kIROp_Store: + case kIROp_SwizzledStore: + if (inst->getOperand(0)->getOp() == kIROp_ImageSubscript) + { + legalizeImageSubscriptStoreForGLSL(builder, inst); + } + } + } + } + } +} + + // // Legalization of entry points for GLSL: // @@ -1341,6 +1474,16 @@ void legalizeEntryPointParameterForGLSL( } } + if (stage == Stage::Geometry) + { + // If the user provided no parameters with a input primitive type qualifier, we + // default to `triangle`. + if (!func->findDecoration<IRGeometryInputPrimitiveTypeDecoration>()) + { + builder->addDecoration(func, kIROp_TriangleInputPrimitiveTypeDecoration); + } + } + // There *can* be multiple streamout parameters, to an entry point (points if nothing else) { IRType* type = pp->getFullType(); diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h index 0c6bf5196..920715be2 100644 --- a/source/slang/slang-ir-glsl-legalize.h +++ b/source/slang/slang-ir-glsl-legalize.h @@ -20,4 +20,6 @@ void legalizeEntryPointsForGLSL( DiagnosticSink* sink, GLSLExtensionTracker* glslExtensionTracker); +void legalizeImageSubscriptForGLSL(IRModule* module); + } diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index a8347438b..ed4b34875 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -1,6 +1,8 @@ // slang-ir-inline.cpp #include "slang-ir-inline.h" +#include "slang-ir-ssa-simplification.h" + // This file provides general facilities for inlining function calls. // @@ -33,17 +35,18 @@ struct InliningPassBase } /// Consider all the call sites in the module for inliing - void considerAllCallSites() + bool considerAllCallSites() { - considerAllCallSitesRec(m_module->getModuleInst()); + return considerAllCallSitesRec(m_module->getModuleInst()); } /// Consider all call sites at or under `inst` for inlining - void considerAllCallSitesRec(IRInst* inst) + bool considerAllCallSitesRec(IRInst* inst) { + bool changed = false; if( auto call = as<IRCall>(inst) ) { - considerCallSite(call); + changed = considerCallSite(call); } // Note: we defensively iterate through the child instructions @@ -54,8 +57,9 @@ struct InliningPassBase for( auto child = inst->getFirstChild(); child; child = next ) { next = child->getNextInst(); - considerAllCallSitesRec(child); + changed |= considerAllCallSitesRec(child); } + return changed; } // In order to inline a call site, we need certain information @@ -93,7 +97,7 @@ struct InliningPassBase // basic proces of considering a call site for inlining. /// Consider the given `call` site, and possibly inline it. - void considerCallSite(IRCall* call) + bool considerCallSite(IRCall* call) { // We start by checking if inlining would even be possible, // since doing so collects information about the call site @@ -104,7 +108,7 @@ struct InliningPassBase // CallSiteInfo callSite; if(!canInline(call, callSite)) - return; + return false; // If we've decided that we *can* inline the given call // site, we next need to check if we *should*. The rules @@ -112,13 +116,14 @@ struct InliningPassBase // so `shouldInline` is a virtual method. // if(!shouldInline(callSite)) - return; + return false; // Finally, if we both *can* and *should* inline the // given call site, we hand off the a worker routine // that does the meat of the work. // inlineCallSite(callSite); + return true; } // Every subclas of `InliningPassBase` should provide its own @@ -313,12 +318,11 @@ struct InliningPassBase } // For now, our inlining pass only handles the case where - // the callee is a "trivial" function, which can support - // a very simple approach to inlining. - // - if( isTrivialFunc(callee) ) + // the callee is a "single-return" function, which means the callee + // function contains only one return at the end of the body. + if (isSingleReturnFunc(callee)) { - inlineTrivialFuncBody(callSite, &env, &builder); + inlineSingleReturnFuncBody(callSite, &env, &builder); } else { @@ -329,26 +333,34 @@ struct InliningPassBase } } - /// Check if `func` represents a trivial single-block callee that can be inlined simply - bool isTrivialFunc(IRFunc* func) + /// Check if `func` represents a simple callee that has only a single `return`. + bool isSingleReturnFunc(IRFunc* func) { - // The function must have a single bocy block to be trivial. - // auto firstBlock = func->getFirstBlock(); - if( firstBlock->getNextBlock() ) - return false; // If the body block is decorated (for some reason), then the function is non-trivial. // if( firstBlock->getFirstDecoration() ) return false; - // If the body block terminates in something other than a `return` then the function is non-trivial. - // - auto terminator = firstBlock->getTerminator(); - if( !as<IRReturn>(terminator) ) - return false; - + // If the body has more than one returns, we cannot inline it now. + bool returnFound = false; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_ReturnVal || inst->getOp() == kIROp_ReturnVoid) + { + // If the return is not at the end of the block, we cannot handle it. + if (inst != block->getTerminator()) + return false; + // If there is already a return found, this function cannot be simple. + if (returnFound) + return false; + returnFound = true; + } + } + } return true; } @@ -374,13 +386,8 @@ struct InliningPassBase // serialization. // // For now this punts on this, and just assumes [__unsafeForceInlineEarly] is not in user code. - static IRInst* _cloneInstWithSourceLoc(CallSiteInfo const& callSite, - IRCloneEnv* env, - IRBuilder* builder, - IRInst* inst) + static void _setSourceLoc(IRInst* clonedInst, IRInst* srcInst, CallSiteInfo const& callSite) { - IRInst* clonedInst = cloneInst(env, builder, inst); - SourceLoc sourceLoc; if (callSite.call->sourceLoc.isValid()) @@ -388,67 +395,129 @@ struct InliningPassBase // Default to using the source loc at the call site sourceLoc = callSite.call->sourceLoc; } - else if (inst->sourceLoc.isValid()) + else if (srcInst->sourceLoc.isValid()) { // If we don't have that copy the inst being cloned sourceLoc - sourceLoc = inst->sourceLoc; + sourceLoc = srcInst->sourceLoc; } clonedInst->sourceLoc = sourceLoc; + } + static IRInst* _cloneInstWithSourceLoc(CallSiteInfo const& callSite, + IRCloneEnv* env, + IRBuilder* builder, + IRInst* inst) + { + IRInst* clonedInst = cloneInst(env, builder, inst); + _setSourceLoc(clonedInst, inst, callSite); return clonedInst; } - /// Inline the body of the callee for `callSite`, where the callee is trivial as tested by `isTrivialFunc` - void inlineTrivialFuncBody(CallSiteInfo const& callSite, IRCloneEnv* env, IRBuilder* builder) + /// Inline the body of the callee for `callSite`, where the callee has a single return. + void inlineSingleReturnFuncBody( + CallSiteInfo const& callSite, IRCloneEnv* env, IRBuilder* builder) { auto call = callSite.call; auto callee = callSite.callee; - auto firstBlock = callee->getFirstBlock(); - // We know that the callee has a single block, so if we encounter + // We know that the callee has a single return block, so if we encounter // a `returnVal` instruction then it must be the one and only - // return point for the block, and its operand will be the value - // the calee returns. + // return point for the function, and its operand will be the value + // the callee returns. // IRInst* returnedValue = nullptr; - // We will loop over the instructions of the one and only block, - // and clone each of them appropriately. - // - for( auto inst : firstBlock->getChildren() ) + // Break the basic block containing the call inst into two basic blocks. + auto callerBlock = callSite.call->getParent(); + builder->setInsertInto(callerBlock->getParent()); + auto afterBlock = builder->createBlock(); + + // Many operations (e.g. `cloneInst`) has define-before-use assumptions on the IR. + // It is important to make sure we keep the ordering of blocks by inserting the + // second half of the basic block right after `callerBlock`. + afterBlock->insertAfter(callerBlock); + afterBlock->sourceLoc = callSite.call->getNextInst()->sourceLoc; + // Move all insts after the call in `callerBlock` to `afterBlock`. { - switch( inst->getOp() ) + auto inst = callSite.call->getNextInst(); + while (inst) { - default: - // The default value is to clone the instruction using - // the existing cloning infrastructure and the `env` - // we have already set up. - // - // SourceLoc information is copied if there is appropriate data available. - _cloneInstWithSourceLoc(callSite, env, builder, inst); - break; - - case kIROp_Param: - // Parameters can be completely ignored in the single-block - // case, because they have all been replaced via `env`. - break; - - case kIROp_ReturnVoid: - // A return with no operand can be ignored, since a return - // from the inlined call should just continue after the - // call site. - // - break; - - case kIROp_ReturnVal: - // A return with a value is similar to `returnVoid` except - // that we need to note the (clone of the) value being - // returned, so that we can use it to replace the value - // of the original call. - // - returnedValue = findCloneForOperand(env, inst->getOperand(0)); - break; + auto next = inst->getNextInst(); + inst->removeFromParent(); + inst->insertAtEnd(afterBlock); + inst = next; + } + } + + List<IRBlock*> clonedBlocks; + for (auto calleeBlock : callee->getBlocks()) + { + auto clonedBlock = builder->createBlock(); + clonedBlock->insertBefore(afterBlock); + _setSourceLoc(clonedBlock, calleeBlock, callSite); + env->mapOldValToNew[calleeBlock] = clonedBlock; + } + + // Insert a branch into the cloned first block at the end of `callerBlock`. + builder->setInsertInto(callerBlock); + auto newBranch = builder->emitBranch(as<IRBlock>(env->mapOldValToNew[callee->getFirstBlock()].GetValue())); + _setSourceLoc(newBranch, call, callSite); + // Clone all basic blocks over to the call site. + bool isFirstBlock = true; + for (auto calleeBlock : callee->getBlocks()) + { + auto clonedBlock = env->mapOldValToNew[calleeBlock].GetValue(); + builder->setInsertInto(clonedBlock); + // We will loop over the instructions of the each block, + // and clone each of them appropriately. + // + for (auto inst : calleeBlock->getChildren()) + { + if (inst->getOp() == kIROp_Param) + { + // Parameters in the first block can be completely ignored + // because they have all been replaced via `env`. + if (isFirstBlock) + { + continue; + } + } + + switch (inst->getOp()) + { + default: + // The default value is to clone the instruction using + // the existing cloning infrastructure and the `env` + // we have already set up. + // + // SourceLoc information is copied if there is appropriate data available. + _cloneInstWithSourceLoc(callSite, env, builder, inst); + break; + + case kIROp_ReturnVoid: + // A return with no operand is replaced with a branch into `afterBlock` + // to return the control flow to the location after the original `call`. + { + auto returnBranch = builder->emitBranch(afterBlock); + _setSourceLoc(returnBranch, inst, callSite); + } + break; + + case kIROp_ReturnVal: + // A return with a value is similar to `returnVoid` except + // that we need to note the (clone of the) value being + // returned, so that we can use it to replace the value + // of the original call. + // + { + auto returnBranch = builder->emitBranch(afterBlock); + _setSourceLoc(returnBranch, inst, callSite); + returnedValue = findCloneForOperand(env, inst->getOperand(0)); + } + break; + } } + isFirstBlock = false; } // If there was a `returnVal` instruction that established @@ -492,4 +561,73 @@ void performMandatoryEarlyInlining(IRModule* module) pass.considerAllCallSites(); } + + // Defined in slang-ir-specialize-resource.cpp +bool isResourceType(IRType* type); +bool isIllegalGLSLParameterType(IRType* type); + + /// An inlining pass that inlines calls functions that returns resources. + /// This is needed for glsl targets. +struct GLSLResourceReturnFunctionInliningPass : InliningPassBase +{ + typedef InliningPassBase Super; + + GLSLResourceReturnFunctionInliningPass(IRModule* module) + : Super(module) + {} + + bool shouldInline(CallSiteInfo const& info) + { + if (isResourceType(info.callee->getResultType())) + { + return true; + } + for (auto param : info.callee->getParams()) + { + if (isIllegalGLSLParameterType(param->getDataType())) + return true; + auto outType = as<IROutTypeBase>(param->getDataType()); + if (!outType) + continue; + auto outValueType = outType->getValueType(); + if (isResourceType(outValueType)) + return true; + } + return false; + } +}; + +void performGLSLResourceReturnFunctionInlining(IRModule* module) +{ + GLSLResourceReturnFunctionInliningPass pass(module); + bool changed = true; + + while (changed) + { + changed = pass.considerAllCallSites(); + simplifyIR(module); + } +} + +struct CustomInliningPass : InliningPassBase +{ + typedef InliningPassBase Super; + + CustomInliningPass(IRModule* module) + : Super(module) + {} + + bool shouldInline(CallSiteInfo const&) + { + return true; + } +}; + +bool inlineCall(IRCall* call) +{ + CustomInliningPass pass(call->getModule()); + return pass.considerCallSite(call); +} + + } // namespace Slang diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index 322b4c855..8ac23f6b0 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -4,7 +4,14 @@ namespace Slang { struct IRModule; + struct IRCall; /// Inline any call sites to functions marked `[unsafeForceInlineEarly]` void performMandatoryEarlyInlining(IRModule* module); + + /// Inline calls to functions that returns a resource/sampler via either return value or output parameter. + void performGLSLResourceReturnFunctionInlining(IRModule* module); + + /// Inline a specific call. + bool inlineCall(IRCall* call); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 0fca118d1..a3486ee68 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -296,6 +296,11 @@ INST(getAddr, getAddr, 1, 0) // "Subscript" an image at a pixel coordinate to get pointer INST(ImageSubscript, imageSubscript, 2, 0) +// Load from an Image. +INST(ImageLoad, imageLoad, 2, 0) +// Store into an Image. +INST(ImageStore, imageStore, 3, 0) + // Load (almost) arbitrary-type data from a byte-address buffer // // %dst = byteAddressBufferLoad(%buffer, %offset) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9b50047de..09a3bdbb3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1368,6 +1368,27 @@ struct IRGetAddress : IRInst IR_LEAF_ISA(getAddr); }; +struct IRImageSubscript : IRInst +{ + IR_LEAF_ISA(ImageSubscript); + IRInst* getImage() { return getOperand(0); } + IRInst* getCoord() { return getOperand(1); } +}; + +struct IRImageLoad : IRInst +{ + IR_LEAF_ISA(ImageLoad); + IRInst* getImage() { return getOperand(0); } + IRInst* getCoord() { return getOperand(1); } +}; + +struct IRImageStore : IRInst +{ + IR_LEAF_ISA(ImageStore); + IRInst* getImage() { return getOperand(0); } + IRInst* getCoord() { return getOperand(1); } + IRInst* getValue() { return getOperand(2); } +}; // Terminators struct IRReturn : IRTerminatorInst @@ -2413,6 +2434,17 @@ public: IRInst* dstPtr, IRInst* srcVal); + IRInst* emitImageLoad( + IRType* type, + IRInst* image, + IRInst* coord); + + IRInst* emitImageStore( + IRType* type, + IRInst* image, + IRInst* coord, + IRInst* value); + IRInst* emitFieldExtract( IRType* type, IRInst* base, diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index 159cf9abc..463de29ca 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -109,6 +109,41 @@ struct SCCPContext } }; + static bool isEvaluableOpCode(IROp op) + { + switch (op) + { + case kIROp_IntLit: + case kIROp_BoolLit: + case kIROp_FloatLit: + case kIROp_StringLit: + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Neg: + case kIROp_Not: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Less: + case kIROp_Greater: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitNot: + case kIROp_BitCast: + case kIROp_Construct: + case kIROp_Select: + return true; + default: + return false; + } + } + // If we imagine a variable (actually an SSA phi node...) that // might be assigned lattice value A at one point in the code, // and lattice value B at another point, we need a way to @@ -204,20 +239,29 @@ struct SCCPContext break; } - // We might be asked for the lattice value of an instruction - // not contained in the current function. When that happens, - // we will treat it as having potentially any value, rather - // than the default of none. - // - auto parentBlock = as<IRBlock>(inst->getParent()); - if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny(); - - // Once the special cases are dealt with, we can look up in - // the dictionary and just return the value we get from it, - // or default to the `None` (empty set) case. + // Look up in the dictionary and just return the value we get from it. LatticeVal latticeVal; if(mapInstToLatticeVal.TryGetValue(inst, latticeVal)) return latticeVal; + + // If we can't find the value from dictionary, we want to return None if this is a value + // in the same function as the one we are working with right now. If it is defined + // elsewhere, we return Any. + auto parentBlock = as<IRBlock>(inst->getParent()); + bool isProcessingGlobalScope = (code == nullptr); + if (!parentBlock && isProcessingGlobalScope) + { + // We are folding constant in the global scope, continue registering the inst as Any. + } + else + { + // If we are processing a function and asked for the lattice value of an instruction + // not contained in the current function, we will treat it as having potentially any + // value, rather than the default of none. + // + if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny(); + } + return LatticeVal::getNone(); } @@ -228,6 +272,460 @@ struct SCCPContext IRBuilder builderStorage; IRBuilder* getBuilder() { return &builderStorage; } + // LatticeVal constant evaluation methods. +#define SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v) \ + switch (v.flavor) \ + { \ + case LatticeVal::Flavor::None: \ + return LatticeVal::getNone(); \ + case LatticeVal::Flavor::Any: \ + return LatticeVal::getAny(); \ + default: \ + break; \ + } + + LatticeVal evalConstruct(IRType* type, LatticeVal v0) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto irConstant = as<IRConstant>(v0.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + switch (irConstant->getOp()) + { + case kIROp_FloatLit: + resultVal = + getBuilder()->getIntValue(type, (IRIntegerValue)irConstant->value.floatVal); + break; + case kIROp_IntLit: + case kIROp_BoolLit: + { + IRIntegerValue intVal = irConstant->value.intVal; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_UInt8Type: + intVal = intVal & 0xFF; + break; + case kIROp_Int16Type: + case kIROp_UInt16Type: + intVal = intVal & 0xFFFF; + break; + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_BoolType: + intVal = intVal & 0xFFFFFFFF; + break; + default: + break; + } + resultVal = getBuilder()->getIntValue(type, (IRIntegerValue)intVal); + } + break; + default: + return LatticeVal::getAny(); + } + break; + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_HalfType: + switch (irConstant->getOp()) + { + case kIROp_FloatLit: + resultVal = getBuilder()->getFloatValue( + type, (IRFloatingPointValue)irConstant->value.floatVal); + break; + case kIROp_IntLit: + case kIROp_BoolLit: + resultVal = getBuilder()->getFloatValue( + type, (IRFloatingPointValue)irConstant->value.intVal); + break; + default: + return LatticeVal::getAny(); + } + break; + case kIROp_BoolType: + switch (irConstant->getOp()) + { + case kIROp_FloatLit: + resultVal = getBuilder()->getBoolValue(irConstant->value.floatVal != 0); + break; + case kIROp_IntLit: + case kIROp_BoolLit: + { + resultVal = getBuilder()->getBoolValue(irConstant->value.intVal != 0); + } + break; + default: + return LatticeVal::getAny(); + } + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + template<typename TIntFunc, typename TFloatFunc> + LatticeVal evalBinaryImpl( + IRType* type, + LatticeVal v0, + LatticeVal v1, + const TIntFunc& intFunc, + const TFloatFunc& floatFunc) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1) + auto c1 = as<IRConstant>(v1.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_BoolType: + resultVal = getBuilder()->getIntValue(type, intFunc(c0->value.intVal, c1->value.intVal)); + break; + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_HalfType: + resultVal = getBuilder()->getFloatValue(type, floatFunc(c0->value.floatVal, c1->value.floatVal)); + break; + default: + break; + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + template <typename TIntFunc> + LatticeVal evalBinaryIntImpl( + IRType* type, + LatticeVal v0, + LatticeVal v1, + const TIntFunc& intFunc) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1) + auto c1 = as<IRConstant>(v1.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_BoolType: + resultVal = + getBuilder()->getIntValue(type, intFunc(c0->value.intVal, c1->value.intVal)); + break; + default: + break; + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + template <typename TIntFunc> + LatticeVal evalUnaryIntImpl( + IRType* type, LatticeVal v0, const TIntFunc& intFunc) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_BoolType: + resultVal = + getBuilder()->getIntValue(type, intFunc(c0->value.intVal)); + break; + default: + break; + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + template <typename TIntFunc, typename TFloatFunc> + LatticeVal evalComparisonImpl( + IRType* type, + LatticeVal v0, + LatticeVal v1, + const TIntFunc& intFunc, + const TFloatFunc& floatFunc) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1) + auto c1 = as<IRConstant>(v1.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_BoolType: + resultVal = + getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal)); + break; + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_HalfType: + resultVal = + getBuilder()->getBoolValue(floatFunc(c0->value.floatVal, c1->value.floatVal)); + break; + default: + break; + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + LatticeVal evalAdd(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 + c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 + c1; }); + } + LatticeVal evalSub(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 - c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 - c1; }); + } + LatticeVal evalMul(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 * c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 * c1; }); + } + LatticeVal evalDiv(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 / c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 / c1; }); + } + LatticeVal evalEql(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalComparisonImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 == c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 == c1; }); + } + LatticeVal evalNeq(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalComparisonImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 != c1; }); + } + LatticeVal evalGeq(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalComparisonImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 >= c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 >= c1; }); + } + LatticeVal evalLeq(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalComparisonImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 <= c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 <= c1; }); + } + LatticeVal evalGreater(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalComparisonImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 > c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 > c1; }); + } + LatticeVal evalLess(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalComparisonImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 < c1; }, + [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 < c1; }); + } + LatticeVal evalAnd(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, + v0, + v1, + [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != 0 && c1 != 0; }); + } + LatticeVal evalOr(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != 0 || c1 != 0; }); + } + LatticeVal evalNot(IRType* type, LatticeVal v0) + { + return evalUnaryIntImpl(type, v0, [](IRIntegerValue c0) { return c0 == 0; }); + } + LatticeVal evalBitAnd(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 & c1; }); + } + LatticeVal evalBitOr(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 | c1; }); + } + LatticeVal evalBitNot(IRType* type, LatticeVal v0) + { + return evalUnaryIntImpl(type, v0, [](IRIntegerValue c0) { return ~c0; }); + } + LatticeVal evalBitXor(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 ^ c1; }); + } + LatticeVal evalLsh(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 << c1; }); + } + LatticeVal evalRsh(IRType* type, LatticeVal v0, LatticeVal v1) + { + return evalBinaryIntImpl( + type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 >> c1; }); + } + LatticeVal evalNeg(IRType* type, LatticeVal v0) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + resultVal = getBuilder()->getIntValue(type, -c0->value.intVal); + break; + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_HalfType: + resultVal = getBuilder()->getFloatValue(type, -c0->value.floatVal); + break; + default: + break; + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + LatticeVal evalBitCast(IRType* type, LatticeVal v0) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + IRInst* resultVal = nullptr; + switch (type->getOp()) + { + case kIROp_Int64Type: + case kIROp_UInt64Type: + resultVal = getBuilder()->getIntValue(type, c0->value.intVal); + break; + case kIROp_IntType: + case kIROp_UIntType: + { + float val = (float)c0->value.floatVal; + uint32_t intVal = (uint32_t)FloatAsInt(val); + resultVal = getBuilder()->getIntValue(type, intVal); + } + break; + case kIROp_FloatType: + { + uint32_t val = (uint32_t)c0->value.intVal; + float floatVal = IntAsFloat((int)val); + resultVal = getBuilder()->getFloatValue(type, floatVal); + } + break; + case kIROp_DoubleType: + resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(c0->value.intVal)); + break; + default: + break; + } + if (!resultVal) + return LatticeVal::getAny(); + return LatticeVal::getConstant(resultVal); + } + + LatticeVal evalSelect(LatticeVal v0, LatticeVal v1, LatticeVal v2) + { + SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) + auto c0 = as<IRConstant>(v0.value); + return c0->value.intVal != 0 ? v1 : v2; + } + // In order to perform constant folding, we need to be able to // interpret an instruction over the lattice values. // @@ -245,11 +743,17 @@ struct SCCPContext case kIROp_BoolLit: return LatticeVal::getConstant(inst); - // TODO: we might also want to special-case certain + // We might also want to special-case certain // instructions where we shouldn't bother trying to // constant-fold them and should just default to the // `Any` value right away. - + case kIROp_Call: + case kIROp_ByteAddressBufferLoad: + case kIROp_ByteAddressBufferStore: + case kIROp_Alloca: + case kIROp_Store: + case kIROp_Load: + return LatticeVal::getAny(); default: break; } @@ -288,10 +792,116 @@ struct SCCPContext // `None` inputs as producing `Any` to make sure we don't // optimize the code based on non-obvious assumptions. // - // For now we aren't implementing *any* folding logic here, - // for simplicity. This is the right place to add folding - // optimizations if/when we need them. - // + // For now we implement only basic folding operations for + // scalar values. + if (!as<IRBasicType>(inst->getDataType())) + return LatticeVal::getAny(); + + switch (inst->getOp()) + { + case kIROp_Construct: + return evalConstruct(inst->getDataType(), getLatticeVal(inst->getOperand(0))); + case kIROp_Add: + return evalAdd( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Sub: + return evalSub( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Mul: + return evalMul( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Div: + return evalDiv( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Eql: + return evalEql( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Neq: + return evalNeq( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Greater: + return evalGreater( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Less: + return evalLess( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Leq: + return evalLeq( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Geq: + return evalGeq( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_And: + return evalAnd( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Or: + return evalOr( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Not: + return evalNot(inst->getDataType(), getLatticeVal(inst->getOperand(0))); + case kIROp_BitAnd: + return evalBitAnd( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_BitOr: + return evalBitOr( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_BitNot: + return evalBitNot(inst->getDataType(), getLatticeVal(inst->getOperand(0))); + case kIROp_BitXor: + return evalBitXor( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_BitCast: + return evalBitCast(inst->getDataType(), getLatticeVal(inst->getOperand(0))); + case kIROp_Neg: + return evalNeg(inst->getDataType(), getLatticeVal(inst->getOperand(0))); + case kIROp_Lsh: + return evalLsh( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Rsh: + return evalRsh( + inst->getDataType(), + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1))); + case kIROp_Select: + return evalSelect( + getLatticeVal(inst->getOperand(0)), + getLatticeVal(inst->getOperand(1)), + getLatticeVal(inst->getOperand(2))); + default: + break; + } // A safe default is to assume that every instruction not // handled by one of the cases above could produce *any* @@ -567,10 +1177,60 @@ struct SCCPContext } } + // Run the constant folding on global scope only. + bool applyOnGlobalScope(IRModule* module) + { + builderStorage.init(shared->sharedBuilder); + for (auto child : module->getModuleInst()->getChildren()) + { + // Only consider evaluable opcodes. + if (!isEvaluableOpCode(child->getOp())) + continue; + + updateValueForInst(child); + } + while (ssaWorkList.getCount()) + { + auto inst = ssaWorkList[0]; + ssaWorkList.fastRemoveAt(0); + // Only consider evaluable opcodes and insts at global scope. + if (!isEvaluableOpCode(inst->getOp()) || inst->getParent() != module->getModuleInst()) + continue; + updateValueForInst(inst); + } + + bool changed = false; + // Replace the insts with their values. + List<IRInst*> instsToRemove; + for (auto child : module->getModuleInst()->getChildren()) + { + if (!isEvaluableOpCode(child->getOp())) + continue; + + auto latticeVal = getLatticeVal(child); + if (latticeVal.flavor == LatticeVal::Flavor::Constant && latticeVal.value != child) + { + child->replaceUsesWith(latticeVal.value); + instsToRemove.add(child); + } + } + + if (instsToRemove.getCount()) + { + changed = true; + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + // Rebuild global value map. + builderStorage.getSharedBuilder()->deduplicateAndRebuildGlobalNumberingMap(); + } + return changed; + } + // The `apply()` function will run the full algorithm. // - void apply() + bool apply() { + bool changed = false; // We start with the busy-work of setting up our IR builder. // builderStorage.init(shared->sharedBuilder); @@ -733,6 +1393,9 @@ struct SCCPContext } } + if (instsToRemove.getCount() != 0) + changed = true; + // Once we've replaced the uses of instructions that evaluate // to constants, we make a second pass to remove the instructions // themselves (or at least those without side effects). @@ -786,6 +1449,7 @@ struct SCCPContext builder->setInsertBefore(terminator); builder->emitBranch(target); terminator->removeAndDeallocate(); + changed = true; } } else if(auto condBranchInst = as<IRConditionalBranch>(terminator)) @@ -800,6 +1464,7 @@ struct SCCPContext builder->setInsertBefore(terminator); builder->emitBranch(target); terminator->removeAndDeallocate(); + changed = true; } } @@ -911,38 +1576,52 @@ struct SCCPContext builder->emitUnreachable(); } } + return changed; } }; -static void applySparseConditionalConstantPropagationRec( - SharedSCCPContext* shared, +static bool applySparseConditionalConstantPropagationRec( + const SCCPContext& globalContext, IRInst* inst) { + bool changed = false; if( auto code = as<IRGlobalValueWithCode>(inst) ) { if( code->getFirstBlock() ) { SCCPContext context; - context.shared = shared; + context.shared = globalContext.shared; context.code = code; - context.apply(); + context.mapInstToLatticeVal = globalContext.mapInstToLatticeVal; + changed |= context.apply(); } } for( auto childInst : inst->getDecorationsAndChildren() ) { - applySparseConditionalConstantPropagationRec(shared, childInst); + changed |= applySparseConditionalConstantPropagationRec(globalContext, childInst); } + return changed; } -void applySparseConditionalConstantPropagation( +bool applySparseConditionalConstantPropagation( IRModule* module) { SharedSCCPContext shared; shared.module = module; shared.sharedBuilder.init(module); + shared.sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + // First we fold constants at global scope. + SCCPContext globalContext; + globalContext.shared = &shared; + globalContext.code = nullptr; + bool changed = globalContext.applyOnGlobalScope(module); + + // Now run recursive SCCP passes on each child code block. + changed |= applySparseConditionalConstantPropagationRec(globalContext, module->getModuleInst()); - applySparseConditionalConstantPropagationRec(&shared, module->getModuleInst()); + return changed; } } diff --git a/source/slang/slang-ir-sccp.h b/source/slang/slang-ir-sccp.h index b557eefe3..06c5769c8 100644 --- a/source/slang/slang-ir-sccp.h +++ b/source/slang/slang-ir-sccp.h @@ -12,7 +12,8 @@ namespace Slang /// also eliminates conditional branches where the condition will /// always evaluate to a constant (which can lead to entire blocks /// becoming dead code) - void applySparseConditionalConstantPropagation( + /// Returns true if IR is changed. + bool applySparseConditionalConstantPropagation( IRModule* module); } diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp new file mode 100644 index 000000000..af0e7c0ce --- /dev/null +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -0,0 +1,82 @@ +#include "slang-ir-simplify-cfg.h" + +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +bool processFunc(IRFunc* func) +{ + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return false; + + bool changed = false; + + List<IRBlock*> workList; + HashSet<IRBlock*> processedBlock; + workList.add(func->getFirstBlock()); + while (workList.getCount()) + { + auto block = workList.getFirst(); + workList.fastRemoveAt(0); + while (block) + { + // If `block` does not end with an unconditional branch, bail. + if (block->getTerminator()->getOp() != kIROp_unconditionalBranch) + break; + auto branch = as<IRUnconditionalBranch>(block->getTerminator()); + auto successor = branch->getTargetBlock(); + // Only perform the merge if `block` is the only predecessor of `successor`. + // We also need to make sure not to merge a block that serves as the + // merge point in CFG. Such blocks will have more than one use. + if (successor->hasMoreThanOneUse()) + break; + changed = true; + Index paramIndex = 0; + auto inst = successor->getFirstDecorationOrChild(); + while (inst) + { + auto next = inst->getNextInst(); + if (inst->getOp() == kIROp_Param) + { + inst->replaceUsesWith(branch->getArg(paramIndex)); + paramIndex++; + } + else + { + inst->removeFromParent(); + inst->insertAtEnd(block); + } + inst = next; + } + branch->removeAndDeallocate(); + assert(!successor->hasUses()); + successor->removeAndDeallocate(); + } + for (auto successor : block->getSuccessors()) + { + if (processedBlock.Add(successor)) + { + workList.add(successor); + } + } + } + return changed; +} + +bool simplifyCFG(IRModule* module) +{ + bool changed = false; + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(inst)) + { + changed |= processFunc(func); + } + } + return changed; +} + +} // namespace Slang diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h new file mode 100644 index 000000000..3d8729274 --- /dev/null +++ b/source/slang/slang-ir-simplify-cfg.h @@ -0,0 +1,12 @@ +// slang-ir-simplify-cfg.h +#pragma once + +namespace Slang +{ + struct IRModule; + + /// Simplifies control flow graph by merging basic blocks that + /// forms a simple linear chain. + /// Returns true if changed. + bool simplifyCFG(IRModule* module); +} diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index eac5287e5..bcd7b494f 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -123,14 +123,14 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, // the witness table sequential ID passed in. builder->setInsertInto(newDispatchFunc); - + if (witnessTables.getCount() == 1) { // If there is only 1 case, no switch statement is necessary. builder->setInsertInto(newBlock); builder->emitBranch(defaultBlock); } - else + else if (witnessTables.getCount() > 1) { auto breakBlock = builder->emitBlock(); builder->setInsertInto(breakBlock); @@ -144,6 +144,21 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, caseBlocks.getCount(), caseBlocks.getBuffer()); } + else + { + // We have no witness tables that implements this interface. + // Just return a default value. + builder->setInsertInto(newBlock); + if (callInst->getDataType()->getOp() == kIROp_VoidType) + { + builder->emitReturn(); + } + else + { + auto defaultValue = builder->emitConstructorInst(callInst->getDataType(), 0, nullptr); + builder->emitReturn(defaultValue); + } + } // Remove old implementation. dispatchFunc->replaceUsesWith(newDispatchFunc); dispatchFunc->removeAndDeallocate(); diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index ab77ec88e..1e63be890 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -100,7 +100,7 @@ struct FunctionParameterSpecializationContext // With the basic state out of the way, let's walk // through the overall flow of the pass. // - void processModule() + bool processModule() { // We will start by initializing our IR building state. // @@ -112,6 +112,8 @@ struct FunctionParameterSpecializationContext // addCallsToWorkListRec(module->getModuleInst()); + bool changed = false; + // We will process the work list until it goes dry, // treating it like a stack of work items. // @@ -130,8 +132,10 @@ struct FunctionParameterSpecializationContext if( canSpecializeCall(call) ) { specializeCall(call); + changed = true; } } + return changed; } // Setting up the work list is a simple recursive procedure. @@ -353,6 +357,7 @@ struct FunctionParameterSpecializationContext // we need to generate a call to it, and then use the new // call as a replacement for the old one. // + SLANG_ASSERT(newFunc != oldCall->getCallee()); auto newCall = getBuilder()->emitCallInst( oldCall->getFullType(), newFunc, @@ -877,7 +882,7 @@ struct FunctionParameterSpecializationContext // is straighforward. We set up the context object // and then defer to it for the real work. // -void specializeFunctionCalls( +bool specializeFunctionCalls( BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module, @@ -889,7 +894,7 @@ void specializeFunctionCalls( context.module = module; context.condition = condition; - context.processModule(); + return context.processModule(); } } // namesapce Slang diff --git a/source/slang/slang-ir-specialize-function-call.h b/source/slang/slang-ir-specialize-function-call.h index 868f9def2..4afb0526f 100644 --- a/source/slang/slang-ir-specialize-function-call.h +++ b/source/slang/slang-ir-specialize-function-call.h @@ -26,8 +26,8 @@ namespace Slang /// a specialized variant of the function that does not have /// those resource parameters (and instead, e.g, refers to the /// global shader parameters directly). - /// - void specializeFunctionCalls( + /// Returns true if any changes are made. + bool specializeFunctionCalls( BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module, diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 1b3fc65f4..01e1c396c 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -6,6 +6,9 @@ #include "slang-ir-insts.h" #include "slang-ir-clone.h" +#include "slang-ir-ssa-simplification.h" + +#include "slang-ir-inline.h" namespace Slang { @@ -55,14 +58,11 @@ struct ResourceParameterSpecializationCondition : FunctionCallSpecializeConditio // For GL/Vulkan targets, we also need to specialize // any parameters that use structured or byte-addressed - // buffers. + // buffers or images with format qualifiers. // if( isKhronosTarget(targetRequest) ) { - if(as<IRHLSLStructuredBufferTypeBase>(type)) - return true; - if(as<IRByteAddressBufferTypeBase>(type)) - return true; + return isIllegalGLSLParameterType(type); } // For now, we will not treat any other parameters as @@ -84,14 +84,22 @@ struct ResourceParameterSpecializationCondition : FunctionCallSpecializeConditio } }; -void specializeResourceParameters( +bool specializeResourceParameters( BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module) { + bool result = false; ResourceParameterSpecializationCondition condition; condition.targetRequest = targetRequest; - specializeFunctionCalls(compileRequest, targetRequest, module, &condition); + bool changed = true; + while (changed) + { + changed = specializeFunctionCalls(compileRequest, targetRequest, module, &condition); + simplifyIR(module); + result |= changed; + } + return result; } /// A pass to specialize resource-typed function outputs @@ -111,8 +119,13 @@ struct ResourceOutputSpecializationPass SharedIRBuilder sharedBuilder; SharedIRBuilder* getSharedBuilder() { return &sharedBuilder; } - void processModule() + // Functions that requires specialization but are currently unspecializable. + List<IRFunc*>* unspecializableFuncs; + + bool processModule() { + bool changed = false; + // We start by setting up the shared IR building state. // sharedBuilder.init(module); @@ -127,11 +140,12 @@ struct ResourceOutputSpecializationPass if(!func) continue; - processFunc(func); + changed |= processFunc(func); } + return changed; } - void processFunc(IRFunc* oldFunc) + bool processFunc(IRFunc* oldFunc) { // We don't want to waste any effort on functions that don't merit // specialization, so the first step is to identify if the function @@ -141,7 +155,7 @@ struct ResourceOutputSpecializationPass // the given function. // if(!shouldSpecializeFunc(oldFunc)) - return; + return false; // It is possible that we have a function that we *should* specialize // (based on its signature), but we *cannot* yet specialize it. @@ -201,7 +215,9 @@ struct ResourceOutputSpecializationPass // are sure we can optimize/simplify, so that the error // messages can be front-end rather than back-end errors. // - return; + newFunc->removeAndDeallocate(); + unspecializableFuncs->add(oldFunc); + return false; } // Specialization might have changed the signature of `newFunc`, @@ -265,6 +281,7 @@ struct ResourceOutputSpecializationPass { specializeCallSite(oldCall, newFunc, funcInfo); } + return true; } // With the overall flow of the pass described, we can now drill down @@ -1096,16 +1113,13 @@ struct ResourceOutputSpecializationPass // but transforming them so that the function signatures are changed makes // the challenge more explicit and thus perhaps easier to tackle. - // TODO: We probably need to update the two passes in this file so that they - // work in an iterative fashion (combined with some SSA "cleanup" on function - // bodies), because each optimization may open up opportunties for the other - // to apply. }; -void specializeResourceOutputs( +bool specializeResourceOutputs( BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, - IRModule* module) + IRModule* module, + List<IRFunc*>& unspecializableFuncs) { if(isD3DTarget(targetRequest) || isKhronosTarget(targetRequest)) {} @@ -1118,14 +1132,102 @@ void specializeResourceOutputs( // of conditional in a way that doesn't involve explicitly // enumerating matching targets. // - return; + return false; } ResourceOutputSpecializationPass pass; pass.compileRequest = compileRequest; pass.targetRequest = targetRequest; pass.module = module; - pass.processModule(); + pass.unspecializableFuncs = &unspecializableFuncs; + return pass.processModule(); +} + +bool specializeResourceUsage( + BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* irModule) +{ + bool result = false; + // We apply two kinds of specialization to clean up resource value usage: + // + // * Specalize call sites based on the actual resources + // that a called function will return/output. + // + // * Specialize called functions based on the actual resources + // passed as input at specific call sites. + // + // We need to run the two passes in an iterative fashion (combined with IR + // simplification passes), because each optimization may open up opportunties + // for the other to apply. + // + for (;;) + { + bool changed = true; + List<IRFunc*> unspecializableFuncs; + while (changed) + { + changed = false; + unspecializableFuncs.clear(); + // Because the legalization may depend on what target + // we are compiling for (certain things might be okay + // for D3D targets that are not okay for Vulkan), we + // pass down the target request along with the IR. + // + changed |= specializeResourceOutputs( + compileRequest, targetRequest, irModule, unspecializableFuncs); + changed |= specializeResourceParameters(compileRequest, targetRequest, irModule); + + // After specialization of function outputs, we may find that there + // are cases where opaque-typed local variables can now be eliminated + // and turned into SSA temporaries. Such optimization may enable + // the following passes to "see" and specialize more cases. + // + simplifyIR(irModule); + result |= changed; + } + if (unspecializableFuncs.getCount() == 0) + break; + + // Inline unspecializable resource output functions and then continue trying. + for (auto func : unspecializableFuncs) + { + for (auto use = func->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + auto call = as<IRCall>(user); + if (!call) + continue; + if (call->getCallee() != func) + continue; + inlineCall(call); + } + } + simplifyIR(irModule); + } + return result; +} + +bool isIllegalGLSLParameterType(IRType* type) +{ + if (as<IRUniformParameterGroupType>(type)) + return true; + if (as<IRHLSLStructuredBufferTypeBase>(type)) + return true; + if (as<IRByteAddressBufferTypeBase>(type)) + return true; + if (as<IRGLSLImageType>(type)) + return true; + if (auto texType = as<IRTextureType>(type)) + { + switch (texType->getAccess()) + { + case SLANG_RESOURCE_ACCESS_READ_WRITE: + case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: + return true; + default: + break; + } + } + return false; } } // namespace Slang diff --git a/source/slang/slang-ir-specialize-resources.h b/source/slang/slang-ir-specialize-resources.h index 62a2728bc..0f2fdc99e 100644 --- a/source/slang/slang-ir-specialize-resources.h +++ b/source/slang/slang-ir-specialize-resources.h @@ -6,6 +6,7 @@ namespace Slang class BackEndCompileRequest; class TargetRequest; struct IRModule; + struct IRType; /// Specialize calls to functions with resource-type parameters. /// @@ -17,13 +18,20 @@ namespace Slang /// those resource parameters (and instead, e.g, refers to the /// global shader parameters directly). /// - void specializeResourceParameters( + bool specializeResourceParameters( BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module); - void specializeResourceOutputs( + bool specializeResourceOutputs( BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module); + + /// Combined iterative passes of `specializeResourceParameters` and `specializeResourceOutputs`. + bool specializeResourceUsage( + BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* irModule); + + bool isIllegalGLSLParameterType(IRType* type); + } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp new file mode 100644 index 000000000..fcc6dc4ae --- /dev/null +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -0,0 +1,36 @@ +// slang-ir-ssa-simplification.cpp +#include "slang-ir-ssa-simplification.h" +#include "slang-ir.h" +#include "slang-ir-ssa.h" +#include "slang-ir-sccp.h" +#include "slang-ir-dce.h" +#include "slang-ir-simplify-cfg.h" + +namespace Slang +{ + struct IRModule; + + // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass + // until no more changes are possible. + void simplifyIR(IRModule* module) + { + bool changed = true; + const int kMaxIterations = 8; + int iterationCounter = 0; + while (changed && iterationCounter < kMaxIterations) + { + changed = false; + changed |= applySparseConditionalConstantPropagation(module); + changed |= simplifyCFG(module); + + // Note: we disregard the `changed` state from dead code elimination pass since + // SCCP pass could be generating temporarily evaluated constant values and never actually use them. + // DCE will always remove those nearly generated consts and always returns true here. + eliminateDeadCode(module); + + changed |= constructSSA(module); + + iterationCounter++; + } + } +} diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h new file mode 100644 index 000000000..19a39e8d4 --- /dev/null +++ b/source/slang/slang-ir-ssa-simplification.h @@ -0,0 +1,11 @@ +// slang-ir-ssa-simplification.h +#pragma once + +namespace Slang +{ + struct IRModule; + + // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass + // until no more changes are possible. + void simplifyIR(IRModule* module); +} diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 1804664fb..797fcb25c 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -1039,7 +1039,7 @@ static void breakCriticalEdges( } // Construct SSA form for a global value with code -void constructSSA(ConstructSSAContext* context) +bool constructSSA(ConstructSSAContext* context) { // First, detect and and break any critical edges in the CFG, // because our representation of SSA form doesn't allow for them. @@ -1052,7 +1052,7 @@ void constructSSA(ConstructSSAContext* context) // If none of the variables are promote-able, // then we can exit without making any changes if (context->promotableVars.getCount() == 0) - return; + return false; // We are going to walk the blocks in order, // and try to process each, by replacing loads @@ -1187,10 +1187,12 @@ void constructSSA(ConstructSSAContext* context) { var->removeAndDeallocate(); } + + return true; } // Construct SSA form for a global value with code -void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) +bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) { ConstructSSAContext context; context.globalVal = globalVal; @@ -1200,28 +1202,31 @@ void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) context.builder.init(context.sharedBuilder); context.builder.setInsertInto(module); - constructSSA(&context); + return constructSSA(&context); } -void constructSSA(IRModule* module, IRInst* globalVal) +bool constructSSA(IRModule* module, IRInst* globalVal) { switch (globalVal->getOp()) { case kIROp_Func: case kIROp_GlobalVar: - constructSSA(module, (IRGlobalValueWithCode*)globalVal); + return constructSSA(module, (IRGlobalValueWithCode*)globalVal); default: break; } + return false; } -void constructSSA(IRModule* module) +bool constructSSA(IRModule* module) { + bool changed = false; for(auto ii : module->getGlobalInsts()) { - constructSSA(module, ii); + changed |= constructSSA(module, ii); } + return changed; } } diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h index 635810c08..b327802a1 100644 --- a/source/slang/slang-ir-ssa.h +++ b/source/slang/slang-ir-ssa.h @@ -4,6 +4,7 @@ namespace Slang { struct IRModule; - - void constructSSA(IRModule* module); + struct IRGlobalValueWithCode; + bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal); + bool constructSSA(IRModule* module); } diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index ffadf33cf..e8da87187 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -196,6 +196,50 @@ namespace Slang } } + void validateCodeBody(IRValidateContext* context, IRGlobalValueWithCode* code) + { + HashSet<IRBlock*> blocks; + for (auto block : code->getBlocks()) + blocks.Add(block); + auto validateBranchTarget = [&](IRInst* inst, IRBlock* target) + { + validate( + context, + blocks.Contains(target), + inst, + "branch inst must have a valid target block that is defined within the same " + "scope."); + }; + for (auto block : code->getBlocks()) + { + auto terminator = block->getTerminator(); + validate(context, terminator, block, "block must have valid terminator inst."); + switch (terminator->getOp()) + { + case kIROp_conditionalBranch: + validateBranchTarget( + terminator, as<IRConditionalBranch>(terminator)->getTrueBlock()); + validateBranchTarget( + terminator, as<IRConditionalBranch>(terminator)->getFalseBlock()); + break; + case kIROp_loop: + case kIROp_unconditionalBranch: + validateBranchTarget(terminator, as<IRUnconditionalBranch>(terminator)->getTargetBlock()); + break; + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(terminator); + for (UInt i = 0; i < switchInst->getCaseCount(); i++) + { + validateBranchTarget(switchInst, switchInst->getCaseLabel(i)); + } + validateBranchTarget(switchInst, switchInst->getDefaultLabel()); + validateBranchTarget(switchInst, switchInst->getBreakLabel()); + } + } + } + } + void validateIRInst( IRValidateContext* context, IRInst* inst) @@ -207,6 +251,11 @@ namespace Slang // If `inst` is itself a parent instruction, then we need to recursively // validate its children. validateIRInstChildren(context, inst); + + if (auto code = as<IRGlobalValueWithCode>(inst)) + { + validateCodeBody(context, code); + } } void validateIRModule(IRModule* module, DiagnosticSink* sink) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f7ebfdb64..721488f82 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3541,6 +3541,21 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitImageLoad(IRType* type, IRInst* image, IRInst* coord) + { + auto inst = createInst<IRImageLoad>(this, kIROp_ImageLoad, type, image, coord); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitImageStore(IRType* type, IRInst* image, IRInst* coord, IRInst* value) + { + IRInst* args[] = {image, coord, value}; + auto inst = createInst<IRImageStore>(this, kIROp_ImageStore, type, 3, args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitFieldExtract( IRType* type, IRInst* base, @@ -5730,6 +5745,7 @@ namespace Slang case kIROp_makeArray: case kIROp_makeStruct: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads + case kIROp_ImageSubscript: case kIROp_FieldExtract: case kIROp_FieldAddress: case kIROp_getElement: diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index fe32c93b1..e59573cef 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -773,7 +773,7 @@ struct GLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl virtual LayoutRulesImpl* getVaryingInputRules() override; virtual LayoutRulesImpl* getVaryingOutputRules() override; virtual LayoutRulesImpl* getSpecializationConstantRules() override; - virtual LayoutRulesImpl* getShaderStorageBufferRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override; virtual LayoutRulesImpl* getParameterBlockRules() override; LayoutRulesImpl* getRayPayloadParameterRules() override; @@ -794,7 +794,7 @@ struct HLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl virtual LayoutRulesImpl* getVaryingInputRules() override; virtual LayoutRulesImpl* getVaryingOutputRules() override; virtual LayoutRulesImpl* getSpecializationConstantRules() override; - virtual LayoutRulesImpl* getShaderStorageBufferRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override; virtual LayoutRulesImpl* getParameterBlockRules() override; LayoutRulesImpl* getRayPayloadParameterRules() override; @@ -815,7 +815,7 @@ struct CPULayoutRulesFamilyImpl : LayoutRulesFamilyImpl virtual LayoutRulesImpl* getVaryingInputRules() override; virtual LayoutRulesImpl* getVaryingOutputRules() override; virtual LayoutRulesImpl* getSpecializationConstantRules() override; - virtual LayoutRulesImpl* getShaderStorageBufferRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override; virtual LayoutRulesImpl* getParameterBlockRules() override; LayoutRulesImpl* getRayPayloadParameterRules() override; @@ -835,7 +835,7 @@ struct CUDALayoutRulesFamilyImpl : LayoutRulesFamilyImpl virtual LayoutRulesImpl* getVaryingInputRules() override; virtual LayoutRulesImpl* getVaryingOutputRules() override; virtual LayoutRulesImpl* getSpecializationConstantRules() override; - virtual LayoutRulesImpl* getShaderStorageBufferRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override; virtual LayoutRulesImpl* getParameterBlockRules() override; LayoutRulesImpl* getRayPayloadParameterRules() override; @@ -1140,8 +1140,10 @@ LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() return &kGLSLSpecializationConstantLayoutRulesImpl_; } -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest* request) { + if (request->getForceGLSLScalarBufferLayout()) + return &kHLSLStructuredBufferLayoutRulesImpl_; return &kStd430LayoutRulesImpl_; } @@ -1219,7 +1221,7 @@ LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() return nullptr; } -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest*) { return nullptr; } @@ -1273,7 +1275,7 @@ LayoutRulesImpl* CPULayoutRulesFamilyImpl::getSpecializationConstantRules() { return nullptr; } -LayoutRulesImpl* CPULayoutRulesFamilyImpl::getShaderStorageBufferRules() +LayoutRulesImpl* CPULayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest*) { return nullptr; } @@ -1339,7 +1341,7 @@ LayoutRulesImpl* CUDALayoutRulesFamilyImpl::getSpecializationConstantRules() { return nullptr; } -LayoutRulesImpl* CUDALayoutRulesFamilyImpl::getShaderStorageBufferRules() +LayoutRulesImpl* CUDALayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest*) { return nullptr; } @@ -2538,7 +2540,8 @@ static RefPtr<TypeLayout> _createParameterGroupTypeLayout( LayoutRulesImpl* getParameterBufferElementTypeLayoutRules( ParameterGroupType* parameterGroupType, - LayoutRulesImpl* rules) + LayoutRulesImpl* rules, + TargetRequest* targetRequest) { if( as<ConstantBufferType>(parameterGroupType) ) { @@ -2558,7 +2561,7 @@ LayoutRulesImpl* getParameterBufferElementTypeLayoutRules( } else if( as<GLSLShaderStorageBufferType>(parameterGroupType) ) { - return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(); + return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(targetRequest); } else if (as<ParameterBlockType>(parameterGroupType)) { @@ -2580,7 +2583,8 @@ RefPtr<TypeLayout> createParameterGroupTypeLayout( // Determine the layout rules to use for the contents of the block auto elementTypeRules = getParameterBufferElementTypeLayoutRules( parameterGroupType, - parameterGroupRules); + parameterGroupRules, + context.targetReq); auto elementType = parameterGroupType->elementType; diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index 6e28b6c9d..219dfad0f 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -1001,7 +1001,7 @@ struct LayoutRulesFamilyImpl virtual LayoutRulesImpl* getVaryingInputRules() = 0; virtual LayoutRulesImpl* getVaryingOutputRules() = 0; virtual LayoutRulesImpl* getSpecializationConstantRules()= 0; - virtual LayoutRulesImpl* getShaderStorageBufferRules() = 0; + virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) = 0; virtual LayoutRulesImpl* getParameterBlockRules() = 0; virtual LayoutRulesImpl* getRayPayloadParameterRules() = 0; diff --git a/tests/bugs/sccp-switch-case-removal.slang b/tests/bugs/sccp-switch-case-removal.slang new file mode 100644 index 000000000..cf8d36ef7 --- /dev/null +++ b/tests/bugs/sccp-switch-case-removal.slang @@ -0,0 +1,25 @@ +//TEST(compute,vulkan):COMPARE_COMPUTE_EX:-vk -slang -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +static const uint kConstant = 5; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float x = 1.0; + switch(kConstant) + { + case 5: + int tid = int(dispatchThreadID.x); + outputBuffer[tid] = 0; + break; + case 1: + ddx(x); // this should trigger glslang error if it doesn't get eliminated by sccp. + break; + default: + ddy(x); // this should trigger glslang error if it doesn't get eliminated by sccp. + break; + } +}
\ No newline at end of file diff --git a/tests/bugs/sccp-switch-case-removal.slang.expected.txt b/tests/bugs/sccp-switch-case-removal.slang.expected.txt new file mode 100644 index 000000000..ae25f7400 --- /dev/null +++ b/tests/bugs/sccp-switch-case-removal.slang.expected.txt @@ -0,0 +1,4 @@ +0 +0 +0 +0
\ No newline at end of file diff --git a/tests/bugs/vk-image-write.slang b/tests/bugs/vk-image-write.slang new file mode 100644 index 000000000..10141384c --- /dev/null +++ b/tests/bugs/vk-image-write.slang @@ -0,0 +1,16 @@ +//TEST:CROSS_COMPILE: -profile ps_5_0 -entry main -target spirv-assembly + +// Ensure that we can lower to `imageStore` correctly. + +RWTexture2D<float4> t; + +void writeColor(float3 v) +{ + t[uint2(0,0)].xyz += v; +} + +float4 main() : SV_Target +{ + writeColor(float3(1.0)); + return float4(0); +} diff --git a/tests/bugs/vk-image-write.slang.glsl b/tests/bugs/vk-image-write.slang.glsl new file mode 100644 index 000000000..7b9bcebf9 --- /dev/null +++ b/tests/bugs/vk-image-write.slang.glsl @@ -0,0 +1,40 @@ +//TEST_IGNORE_FILE: + +#version 450 +layout(row_major) uniform; +layout(row_major) buffer; + +layout(rgba32f) +layout(binding = 0) +uniform image2D t_0; + +void writeColor_0(vec3 v_0) +{ + const uvec2 _S1 = uvec2(0U, 0U); + + vec4 _S2 = (imageLoad((t_0), ivec2((_S1)))); + + vec3 _S3 = _S2.xyz + v_0; + + ivec2 _S4 = ivec2(_S1); + + vec4 _S5 = imageLoad(t_0,_S4); + + vec4 _S6 = _S5; + _S6.xyz = _S3; + + imageStore(t_0,_S4,_S6); + return; +} + + +layout(location = 0) +out vec4 _S7; + +void main() +{ + writeColor_0(vec3(1.00000000000000000000)); + _S7 = vec4(0); + return; +} + diff --git a/tests/bugs/vk-structured-buffer-load.hlsl.glsl b/tests/bugs/vk-structured-buffer-load.hlsl.glsl index 1c7ec8043..05c8de193 100644 --- a/tests/bugs/vk-structured-buffer-load.hlsl.glsl +++ b/tests/bugs/vk-structured-buffer-load.hlsl.glsl @@ -43,13 +43,12 @@ void main() float HitT_0 = (gl_RayTmaxNV); RayData.PackedHitInfoA_0.x = HitT_0; - const uint use_rcp_0 = uint(0); - float offsfloat_0 = ((gParamBlock_sbuf_0)._data[(int(uint(0)))]); + float offsfloat_0 = ((gParamBlock_sbuf_0)._data[(0)]); - uint use_rcp_1 = use_rcp_0|uint(HitT_0 > 0.00000000000000000000); + uint use_rcp_0 = 0U | uint(HitT_0 > 0.00000000000000000000); - if(bool(use_rcp_1)) + if(bool(use_rcp_0)) { float tmpA = rcp_0(offsfloat_0); @@ -60,7 +59,7 @@ void main() else { - if(use_rcp_1 > uint(0)&&offsfloat_0 == 0.00000000000000000000) + if(use_rcp_0 > 0U&&offsfloat_0 == 0.00000000000000000000) { float tmpB = (inversesqrt((offsfloat_0 + 1.00000000000000000000))); diff --git a/tests/optimization/func-resource-result/func-resource-result-complex.slang b/tests/optimization/func-resource-result/func-resource-result-complex.slang new file mode 100644 index 000000000..a5585ff4c --- /dev/null +++ b/tests/optimization/func-resource-result/func-resource-result-complex.slang @@ -0,0 +1,45 @@ +// func-resource-result-simple.slang + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj + +// Test that a function that returns a resource type can be +// compiled for targets that don't natively support resource +// return values. + +//TEST_INPUT:set textures=[Texture2D(size=4, content = zero), Texture2D(size=4, content = one)] +Texture2D textures[2]; + +//TEST_INPUT:set sampler=Sampler +SamplerState sampler; + +Texture2D getTex(int index) +{ + Texture2D result; + // Note: `index` here will need to be a compile time constant in order to generate + // valid GLSL. If constant folding and function inlining are all done correctly + // we should be able to compile this. + if (index == 0) + result = textures[1]; + else + result = textures[0]; + return result; +} + +int test(int val) +{ + // Make sure index is a compile-time constant. + return getTex(int(0.0) + 1*2 < 5 ? 0 : 1).SampleLevel(sampler, float2(0,0), 0).x == 0.0 ? 0 : 1; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt b/tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt new file mode 100644 index 000000000..ef529012e --- /dev/null +++ b/tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt @@ -0,0 +1,4 @@ +1 +1 +1 +1
\ No newline at end of file diff --git a/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl b/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl index 1454d493f..59fecd544 100644 --- a/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl +++ b/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl @@ -1,14 +1,14 @@ +//TEST_IGNORE_FILE: #version 450 - #extension GL_NV_conservative_raster_underestimation : require +layout(row_major) uniform; +layout(row_major) buffer; layout(location = 0) out vec4 _S1; void main() { - vec4 _S2; - _S2 = vec4(uint(gl_FragFullyCoveredNV)); - _S1 = _S2; + _S1 = vec4(uint(gl_FragFullyCoveredNV)); return; } diff --git a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl index a12b9827b..1818b7789 100644 --- a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl +++ b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl @@ -2,6 +2,8 @@ #version 450 #extension GL_ARB_fragment_shader_interlock : require +layout(row_major) uniform; +layout(row_major) buffer; layout(rgba32f) layout(binding = 0) @@ -15,13 +17,9 @@ out vec4 _S2; void main() { - vec4 _S3; - beginInvocationInterlockARB(); - vec4 _S4 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))))); - - _S3 = _S4; + vec4 _S3 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))))); imageStore((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))), _S3 + _S1); endInvocationInterlockARB(); diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl index c07e9b61c..1da5f4f8a 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl @@ -2,24 +2,18 @@ //TEST_IGNORE_FILE: #version 450 - #extension GL_NV_fragment_shader_barycentric : require +layout(row_major) uniform; +layout(row_major) buffer; pervertexNV layout(location = 0) -in vec4 _S1[3]; +in vec4 _S1[3]; layout(location = 0) out vec4 _S2; void main() { - vec4 _S3; - - _S3 = gl_BaryCoordNV.x * _S1[0] - + gl_BaryCoordNV.y * _S1[1] - + gl_BaryCoordNV.z * _S1[2]; - - _S2 = _S3; - + _S2 = gl_BaryCoordNV.x * ((_S1)[(0U)]) + gl_BaryCoordNV.y * ((_S1)[(1U)]) + gl_BaryCoordNV.z * ((_S1)[(2U)]); return; } diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index abe5a3b4d..7cfdf06c8 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -399,14 +399,20 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL setCurrentValue(uint64_t value) override { - VkSemaphoreSignalInfo signalInfo; - signalInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; - signalInfo.pNext = NULL; - signalInfo.semaphore = m_semaphore; - signalInfo.value = 2; + uint64_t currentValue = 0; + SLANG_VK_RETURN_ON_FAIL(m_device->m_api.vkGetSemaphoreCounterValue( + m_device->m_api.m_device, m_semaphore, ¤tValue)); + if (currentValue < value) + { + VkSemaphoreSignalInfo signalInfo; + signalInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; + signalInfo.pNext = NULL; + signalInfo.semaphore = m_semaphore; + signalInfo.value = value; - SLANG_VK_RETURN_ON_FAIL( - m_device->m_api.vkSignalSemaphore(m_device->m_api.m_device, &signalInfo)); + SLANG_VK_RETURN_ON_FAIL( + m_device->m_api.vkSignalSemaphore(m_device->m_api.m_device, &signalInfo)); + } return SLANG_OK; } @@ -881,7 +887,7 @@ public: ComPtr<IResourceView> depthStencilView; uint32_t m_width; uint32_t m_height; - RefPtr<VKDevice> m_renderer; + BreakableReference<VKDevice> m_renderer; VkClearValue m_clearValues[kMaxAttachments]; RefPtr<FramebufferLayoutImpl> m_layout; public: @@ -911,7 +917,7 @@ public: m_height = getMipLevelSize(viewDesc->subresourceRange.mipLevel, size.height); layerCount = viewDesc->subresourceRange.layerCount; } - else + else if (desc.renderTargetCount) { // If we don't have a depth attachment, then we must have at least // one color attachment. Get frame dimension from there. @@ -923,6 +929,12 @@ public: m_height = getMipLevelSize(viewDesc->subresourceRange.mipLevel, size.height); layerCount = viewDesc->subresourceRange.layerCount; } + else + { + m_width = 1; + m_height = 1; + layerCount = 1; + } if (layerCount == 0) layerCount = 1; // Create render pass. @@ -2978,19 +2990,20 @@ public: Index count = resourceViews.getCount(); for(Index i = 0; i < count; ++i) { - auto bufferView = static_cast<PlainBufferResourceViewImpl*>(resourceViews[i].Ptr()); - VkDescriptorBufferInfo bufferInfo = {}; + bufferInfo.range = VK_WHOLE_SIZE; - if(bufferView) - { - bufferInfo.buffer = bufferView->m_buffer->m_buffer.m_buffer; - bufferInfo.offset = bufferView->offset; - bufferInfo.range = bufferView->size; - } - else + if (resourceViews[i]) { - bufferInfo.range = VK_WHOLE_SIZE; + auto boundViewType = static_cast<ResourceViewImpl*>(resourceViews[i].Ptr())->m_type; + if (boundViewType == ResourceViewImpl::ViewType::PlainBuffer) + { + auto bufferView = + static_cast<PlainBufferResourceViewImpl*>(resourceViews[i].Ptr()); + bufferInfo.buffer = bufferView->m_buffer->m_buffer.m_buffer; + bufferInfo.offset = bufferView->offset; + bufferInfo.range = bufferView->size; + } } VkWriteDescriptorSet write = {}; @@ -3017,11 +3030,17 @@ public: Index count = resourceViews.getCount(); for(Index i = 0; i < count; ++i) { - auto resourceView = static_cast<TexelBufferResourceViewImpl*>(resourceViews[i].Ptr()); VkBufferView bufferView = VK_NULL_HANDLE; - if (resourceView) + if (resourceViews[i]) { - bufferView = resourceView->m_view; + auto boundViewType = + static_cast<ResourceViewImpl*>(resourceViews[i].Ptr())->m_type; + if (boundViewType == ResourceViewImpl::ViewType::TexelBuffer) + { + auto resourceView = + static_cast<TexelBufferResourceViewImpl*>(resourceViews[i].Ptr()); + bufferView = resourceView->m_view; + } } VkWriteDescriptorSet write = {}; write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; @@ -3125,12 +3144,17 @@ public: Index count = resourceViews.getCount(); for(Index i = 0; i < count; ++i) { - auto texture = static_cast<TextureResourceViewImpl*>(resourceViews[i].Ptr()); VkDescriptorImageInfo imageInfo = {}; - if (texture) + if (resourceViews[i]) { - imageInfo.imageView = texture->m_view; - imageInfo.imageLayout = texture->m_layout; + auto boundViewType = + static_cast<ResourceViewImpl*>(resourceViews[i].Ptr())->m_type; + if (boundViewType == ResourceViewImpl::ViewType::Texture) + { + auto texture = static_cast<TextureResourceViewImpl*>(resourceViews[i].Ptr()); + imageInfo.imageView = texture->m_view; + imageInfo.imageLayout = texture->m_layout; + } } imageInfo.sampler = 0; @@ -5081,6 +5105,8 @@ public: void beginPass(IRenderPassLayout* renderPass, IFramebuffer* framebuffer) { FramebufferImpl* framebufferImpl = static_cast<FramebufferImpl*>(framebuffer); + if (!framebuffer) + framebufferImpl = this->m_device->m_emptyFramebuffer; RenderPassLayoutImpl* renderPassImpl = static_cast<RenderPassLayoutImpl*>(renderPass); VkClearValue clearValues[kMaxAttachments] = {}; @@ -6526,6 +6552,8 @@ public: ChunkedList<RefPtr<RefObject>, 1024> m_deviceObjectsWithPotentialBackReferences; VkSampler m_defaultSampler; + + RefPtr<FramebufferImpl> m_emptyFramebuffer; }; void VKDevice::PipelineCommandEncoder::init(CommandBufferImpl* commandBuffer) @@ -6704,7 +6732,9 @@ VKDevice::~VKDevice() m_deviceQueue.destroy(); descriptorSetAllocator.close(); - + + m_emptyFramebuffer = nullptr; + if (m_device != VK_NULL_HANDLE) { if (m_desc.existingDeviceHandles.handles[2].handleValue == 0) @@ -7187,6 +7217,7 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool deviceExtensions.add(VK_KHR_RAY_QUERY_EXTENSION_NAME); m_features.add("ray-query"); m_features.add("ray-tracing"); + m_features.add("sm_6_6"); } if (extendedFeatures.bufferDeviceAddressFeatures.bufferDeviceAddress) @@ -7243,6 +7274,13 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool #endif m_features.add("external-memory"); } + if (extensionNames.Contains(VK_EXT_CONSERVATIVE_RASTERIZATION_EXTENSION_NAME)) + { + deviceExtensions.add(VK_EXT_CONSERVATIVE_RASTERIZATION_EXTENSION_NAME); + m_features.add("conservative-rasterization-3"); + m_features.add("conservative-rasterization-2"); + m_features.add("conservative-rasterization-1"); + } if (extensionNames.Contains(VK_EXT_DEBUG_REPORT_EXTENSION_NAME)) { deviceExtensions.add(VK_EXT_DEBUG_REPORT_EXTENSION_NAME); @@ -7358,6 +7396,21 @@ SlangResult VKDevice::initialize(const Desc& desc) SLANG_VK_RETURN_ON_FAIL(m_api.vkCreateSampler(m_device, &samplerInfo, nullptr, &m_defaultSampler)); } + // Create empty frame buffer. + { + IFramebufferLayout::Desc layoutDesc = {}; + layoutDesc.renderTargetCount = 0; + layoutDesc.depthStencil = nullptr; + ComPtr<IFramebufferLayout> layout; + SLANG_RETURN_ON_FAIL(createFramebufferLayout(layoutDesc, layout.writeRef())); + IFramebuffer::Desc desc = {}; + desc.layout = layout; + ComPtr<IFramebuffer> framebuffer; + SLANG_RETURN_ON_FAIL(createFramebuffer(desc, framebuffer.writeRef())); + m_emptyFramebuffer = static_cast<FramebufferImpl*>(framebuffer.get()); + m_emptyFramebuffer->m_renderer.breakStrongReference(); + } + return SLANG_OK; } @@ -9194,6 +9247,16 @@ Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in rasterizer.depthBiasSlopeFactor = rasterizerDesc.slopeScaledDepthBias; rasterizer.lineWidth = 1.0f; // TODO: Currently unsupported + VkPipelineRasterizationConservativeStateCreateInfoEXT conservativeRasterInfo = {}; + conservativeRasterInfo.sType = + VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_CONSERVATIVE_STATE_CREATE_INFO_EXT; + conservativeRasterInfo.conservativeRasterizationMode = + VK_CONSERVATIVE_RASTERIZATION_MODE_OVERESTIMATE_EXT; + if (desc.rasterizer.enableConservativeRasterization) + { + rasterizer.pNext = &conservativeRasterInfo; + } + auto framebufferLayoutImpl = static_cast<FramebufferLayoutImpl*>(desc.framebufferLayout); auto forcedSampleCount = rasterizerDesc.forcedSampleCount; auto blendDesc = desc.blend; |
