summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-25 20:49:31 -0800
committerGitHub <noreply@github.com>2022-02-25 20:49:31 -0800
commitc31577953d5041c82375c22d847c2eba06106c58 (patch)
treebc685a8b63fc13cb85d160ae13df950056ca6e91
parent8990d270e3a0c01b1f7abbf4f79556c5ef82a096 (diff)
Improved SCCP, inlining and resource specialization passes, legalize `ImageSubscript` for GLSL (#2146)
-rw-r--r--build/visual-studio/slang/slang.vcxproj4
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters12
-rw-r--r--source/slang/hlsl.meta.slang66
-rw-r--r--source/slang/slang-emit-c-like.cpp13
-rw-r--r--source/slang/slang-emit-cuda.cpp18
-rw-r--r--source/slang/slang-emit-cuda.h1
-rw-r--r--source/slang/slang-emit-glsl.cpp32
-rw-r--r--source/slang/slang-emit.cpp74
-rw-r--r--source/slang/slang-ir-dce.cpp15
-rw-r--r--source/slang/slang-ir-dce.h4
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp145
-rw-r--r--source/slang/slang-ir-glsl-legalize.h2
-rw-r--r--source/slang/slang-ir-inline.cpp284
-rw-r--r--source/slang/slang-ir-inline.h7
-rw-r--r--source/slang/slang-ir-inst-defs.h5
-rw-r--r--source/slang/slang-ir-insts.h32
-rw-r--r--source/slang/slang-ir-sccp.cpp729
-rw-r--r--source/slang/slang-ir-sccp.h3
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp82
-rw-r--r--source/slang/slang-ir-simplify-cfg.h12
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp19
-rw-r--r--source/slang/slang-ir-specialize-function-call.cpp11
-rw-r--r--source/slang/slang-ir-specialize-function-call.h4
-rw-r--r--source/slang/slang-ir-specialize-resources.cpp142
-rw-r--r--source/slang/slang-ir-specialize-resources.h12
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp36
-rw-r--r--source/slang/slang-ir-ssa-simplification.h11
-rw-r--r--source/slang/slang-ir-ssa.cpp21
-rw-r--r--source/slang/slang-ir-ssa.h5
-rw-r--r--source/slang/slang-ir-validate.cpp49
-rw-r--r--source/slang/slang-ir.cpp16
-rw-r--r--source/slang/slang-type-layout.cpp26
-rw-r--r--source/slang/slang-type-layout.h2
-rw-r--r--tests/bugs/sccp-switch-case-removal.slang25
-rw-r--r--tests/bugs/sccp-switch-case-removal.slang.expected.txt4
-rw-r--r--tests/bugs/vk-image-write.slang16
-rw-r--r--tests/bugs/vk-image-write.slang.glsl40
-rw-r--r--tests/bugs/vk-structured-buffer-load.hlsl.glsl9
-rw-r--r--tests/optimization/func-resource-result/func-resource-result-complex.slang45
-rw-r--r--tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt4
-rw-r--r--tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl8
-rw-r--r--tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl8
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl14
-rw-r--r--tools/gfx/vulkan/render-vk.cpp117
44 files changed, 1927 insertions, 257 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 4c6bcbea2..c2387346d 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -377,6 +377,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure-scoping.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-sccp.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-simplify-cfg.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-arrays.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-buffer-load-arg.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-dispatch.h" />
@@ -386,6 +387,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-ssa-simplification.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-string-hash.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-strip-witness-tables.h" />
@@ -508,6 +510,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure-scoping.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-sccp.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-simplify-cfg.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-arrays.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-buffer-load-arg.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-dispatch.cpp" />
@@ -517,6 +520,7 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-ssa-simplification.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-string-hash.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-strip-witness-tables.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index caea98ea9..d86cdcbea 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -228,6 +228,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-sccp.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-simplify-cfg.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-arrays.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -255,6 +258,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-ssa-simplification.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -617,6 +623,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-sccp.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-simplify-cfg.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-arrays.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -644,6 +653,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-ssa-simplification.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 2e1ab33f2..d3d77b804 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -64,6 +64,11 @@ __target_intrinsic(hlsl, "NvInterlockedAddUint64($0, $1, $2)")
[__requiresNVAPI]
uint2 __atomicAdd(RWByteAddressBuffer buf, uint offset, uint2);
+// atomic add for hlsl using SM6.6
+__target_intrinsic(hlsl, "$0.InterlockedAdd64($1, $2, $3)")
+void __atomicAdd(RWByteAddressBuffer buf, uint offset, int64_t value, out int64_t originalValue);
+__target_intrinsic(hlsl, "$0.InterlockedAdd64($1, $2, $3)")
+void __atomicAdd(RWByteAddressBuffer buf, uint offset, uint64_t value, out uint64_t originalValue);
// Int versions require glsl 4.30
// https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/atomicAdd.xhtml
@@ -81,6 +86,11 @@ __glsl_version(430)
__glsl_extension(GL_EXT_shader_atomic_int64)
int64_t __atomicAdd(__ref int64_t value, int64_t amount);
+__target_intrinsic(glsl, "atomicAdd($0, $1)")
+__glsl_version(430)
+__glsl_extension(GL_EXT_shader_atomic_int64)
+uint64_t __atomicAdd(__ref uint64_t value, uint64_t amount);
+
// Cas - Compare and swap
// Helper for HLSL, using NVAPI
@@ -89,6 +99,17 @@ __target_intrinsic(hlsl, "NvInterlockedCompareExchangeUint64($0, $1, $2, $3)")
[__requiresNVAPI]
uint2 __cas(RWByteAddressBuffer buf, uint offset, uint2 compareValue, uint2 value);
+// CAS using SM6.6
+__target_intrinsic(hlsl, "$0.InterlockedCompareExchange64($1, $2, $3, $4)")
+void __cas(RWByteAddressBuffer buf, uint offset, in int64_t compare_value, in int64_t value, out int64_t original_value);
+__target_intrinsic(hlsl, "$0.InterlockedCompareExchange64($1, $2, $3, $4)")
+void __cas(RWByteAddressBuffer buf, uint offset, in uint64_t compare_value, in uint64_t value, out uint64_t original_value);
+
+__target_intrinsic(glsl, "atomicCompSwap($0, $1, $2)")
+__glsl_version(430)
+__glsl_extension(GL_EXT_shader_atomic_int64)
+int64_t __cas(__ref int64_t ioValue, int64_t compareValue, int64_t newValue); // signed 64-bit compare-and-swap; result type matches the signed operands, like the signed `__atomicAdd` overload
+
__target_intrinsic(glsl, "atomicCompSwap($0, $1, $2)")
__glsl_version(430)
__glsl_extension(GL_EXT_shader_atomic_int64)
@@ -482,6 +503,51 @@ ${{{{
return __atomicExchange(buf[byteAddress / 8], value);
}
+ // SM6.6 64-bit atomics.
+ __specialized_for_target(hlsl)
+ void InterlockedAdd64(uint byteAddress, int64_t valueToAdd, out int64_t outOriginalValue) // HLSL, signed: forwards to the buffer-level `__atomicAdd` overload that takes an `out` original value.
+ {
+ __atomicAdd(this, byteAddress, valueToAdd, outOriginalValue);
+ }
+ __specialized_for_target(glsl)
+ void InterlockedAdd64(uint byteAddress, int64_t valueToAdd, out int64_t originalValue) // GLSL, signed: goes through an equivalent structured buffer of int64_t.
+ {
+ RWStructuredBuffer<int64_t> buf = __getEquivalentStructuredBuffer<int64_t>(this);
+ originalValue = __atomicAdd(buf[byteAddress / 8], valueToAdd); // element index is byteAddress / 8; assumes byteAddress is 8-byte aligned - TODO confirm
+ }
+ __specialized_for_target(hlsl)
+ void InterlockedAdd64(uint byteAddress, uint64_t valueToAdd, out uint64_t outOriginalValue) // HLSL, unsigned: same forwarding as the signed overload.
+ {
+ __atomicAdd(this, byteAddress, valueToAdd, outOriginalValue);
+ }
+ __specialized_for_target(glsl)
+ void InterlockedAdd64(uint byteAddress, uint64_t valueToAdd, out uint64_t originalValue) // GLSL, unsigned: goes through an equivalent structured buffer of uint64_t.
+ {
+ RWStructuredBuffer<uint64_t> buf = __getEquivalentStructuredBuffer<uint64_t>(this);
+ originalValue = __atomicAdd(buf[byteAddress / 8], valueToAdd); // element index is byteAddress / 8; assumes byteAddress is 8-byte aligned - TODO confirm
+ }
+ __specialized_for_target(hlsl)
+ void InterlockedCompareExchange64(uint byteAddress, int64_t compareValue, int64_t value, out int64_t outOriginalValue) // HLSL, signed: forwards to the buffer-level `__cas` overload that takes an `out` original value.
+ {
+ __cas(this, byteAddress, compareValue, value, outOriginalValue);
+ }
+ __specialized_for_target(glsl)
+ void InterlockedCompareExchange64(uint byteAddress, int64_t compareValue, int64_t value, out int64_t outOriginalValue) // GLSL, signed: goes through an equivalent structured buffer of int64_t.
+ {
+ RWStructuredBuffer<int64_t> buf = __getEquivalentStructuredBuffer<int64_t>(this);
+ outOriginalValue = __cas(buf[byteAddress / 8], compareValue, value); // element index is byteAddress / 8; assumes byteAddress is 8-byte aligned - TODO confirm
+ }
+ __specialized_for_target(hlsl)
+ void InterlockedCompareExchange64(uint byteAddress, uint64_t compareValue, uint64_t value, out uint64_t outOriginalValue) // HLSL, unsigned: same forwarding as the signed overload.
+ {
+ __cas(this, byteAddress, compareValue, value, outOriginalValue);
+ }
+ __specialized_for_target(glsl)
+ void InterlockedCompareExchange64(uint byteAddress, uint64_t compareValue, uint64_t value, out uint64_t outOriginalValue) // GLSL, unsigned: goes through an equivalent structured buffer of uint64_t.
+ {
+ RWStructuredBuffer<uint64_t> buf = __getEquivalentStructuredBuffer<uint64_t>(this);
+ outOriginalValue = __cas(buf[byteAddress / 8], compareValue, value); // element index is byteAddress / 8; assumes byteAddress is 8-byte aligned - TODO confirm
+ }
${{{{
}
}}}}
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 03943abb3..6b77cba6a 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -893,10 +893,20 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
return;
}
case BaseType::UInt8:
+ {
+ m_writer->emit(UInt(uint8_t(litInst->value.intVal)));
+ m_writer->emit("U");
+ break;
+ }
case BaseType::UInt16:
+ {
+ m_writer->emit(UInt(uint16_t(litInst->value.intVal)));
+ m_writer->emit("U");
+ break;
+ }
case BaseType::UInt:
{
- m_writer->emit(UInt(litInst->value.intVal));
+ m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
m_writer->emit("U");
break;
}
@@ -1010,6 +1020,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
//
case kIROp_makeStruct:
case kIROp_makeArray:
+ case kIROp_swizzleSet:
return false;
}
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 27eb75d34..f7850179d 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -895,6 +895,24 @@ void CUDASourceEmitter::emitSimpleFuncImpl(IRFunc* func)
CLikeSourceEmitter::emitSimpleFuncImpl(func);
}
+void CUDASourceEmitter::emitSimpleValueImpl(IRInst* inst) // CUDA override: wraps half-typed float literals in `__half(...)`; everything else goes straight to the base emitter.
+{
+ // Make sure we convert float to half when emitting a half literal to avoid
+ // overload ambiguity errors from CUDA.
+ if (inst->getOp() == kIROp_FloatLit)
+ {
+ if (inst->getDataType()->getOp() == kIROp_HalfType)
+ {
+ m_writer->emit("__half(");
+ CLikeSourceEmitter::emitSimpleValueImpl(inst); // the base class emits the literal text itself
+ m_writer->emit(")");
+ return;
+ }
+ }
+ CLikeSourceEmitter::emitSimpleValueImpl(inst); // any other value: unchanged base-class behavior
+}
+
+
void CUDASourceEmitter::emitSemanticsImpl(IRInst* inst)
{
Super::emitSemanticsImpl(inst);
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index 07f429898..ad2f8ed0a 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -72,6 +72,7 @@ protected:
virtual void emitSimpleFuncParamsImpl(IRFunc* func) SLANG_OVERRIDE;
virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE;
virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
+ virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 441fc94d7..44326a18e 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -774,38 +774,38 @@ void GLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst)
{
emitType(type);
m_writer->emit("(");
- m_writer->emit(litInst->value.intVal);
+ m_writer->emit(int8_t(litInst->value.intVal));
m_writer->emit(")");
return;
}
case BaseType::Int16:
{
- m_writer->emit(litInst->value.intVal);
+ m_writer->emit(int16_t(litInst->value.intVal));
m_writer->emit("S");
return;
}
case BaseType::Int:
{
- m_writer->emit(litInst->value.intVal);
+ m_writer->emit(int32_t(litInst->value.intVal));
return;
}
case BaseType::UInt8:
{
emitType(type);
m_writer->emit("(");
- m_writer->emit(UInt(litInst->value.intVal));
+ m_writer->emit(UInt(uint8_t(litInst->value.intVal)));
m_writer->emit("U)");
return;
}
case BaseType::UInt16:
{
- m_writer->emit(UInt(litInst->value.intVal));
+ m_writer->emit(UInt(uint16_t(litInst->value.intVal)));
m_writer->emit("US");
return;
}
case BaseType::UInt:
{
- m_writer->emit(UInt(litInst->value.intVal));
+ m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
m_writer->emit("U");
return;
}
@@ -1636,6 +1636,26 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
emitOperand(inst->getOperand(0), outerPrec);
return true;
}
+ case kIROp_ImageLoad:
+ {
+ m_writer->emit("imageLoad(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_ImageStore:
+ {
+ m_writer->emit("imageStore(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
case kIROp_StructuredBufferLoad:
{
auto outerPrec = inOuterPrec;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 378732fb3..a0ac30857 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -26,11 +26,13 @@
#include "slang-ir-optix-entry-point-uniforms.h"
#include "slang-ir-restructure.h"
#include "slang-ir-restructure-scoping.h"
+#include "slang-ir-sccp.h"
#include "slang-ir-specialize.h"
#include "slang-ir-specialize-arrays.h"
#include "slang-ir-specialize-buffer-load-arg.h"
#include "slang-ir-specialize-resources.h"
#include "slang-ir-ssa.h"
+#include "slang-ir-ssa-simplification.h"
#include "slang-ir-strip-witness-tables.h"
#include "slang-ir-synthesize-active-mask.h"
#include "slang-ir-union.h"
@@ -324,10 +326,13 @@ Result linkAndOptimizeIR(
specializeModule(irModule);
dumpIRIfEnabled(compileRequest, irModule, "AFTER-SPECIALIZE");
+ applySparseConditionalConstantPropagation(irModule);
eliminateDeadCode(irModule);
lowerReinterpret(targetRequest, irModule, sink);
+ validateIRModuleIfEnabled(compileRequest, irModule);
+
// For targets that supports dynamic dispatch, we need to lower the
// generics / interface types to ordinary functions and types using
// function pointers.
@@ -359,10 +364,7 @@ Result linkAndOptimizeIR(
// up downstream passes like type legalization, so we
// will run a DCE pass to clean up after the specialization.
//
- // TODO: Are there other cleanup optimizations we should
- // apply at this point?
- //
- eliminateDeadCode(irModule);
+ simplifyIR(irModule);
#if 0
dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE");
#endif
@@ -435,7 +437,7 @@ Result linkAndOptimizeIR(
// to see if we can clean up any temporaries created by legalization.
// (e.g., things that used to be aggregated might now be split up,
// so that we can work with the individual fields).
- constructSSA(irModule);
+ simplifyIR(irModule);
#if 0
dumpIRIfEnabled(compileRequest, irModule, "AFTER SSA");
@@ -450,36 +452,12 @@ Result linkAndOptimizeIR(
// Many of our targets place restrictions on how certain
// resource types can be used, so that having them as
// function parameters, reults, etc. is invalid.
- // To clean this up, we apply two kinds of specialization:
- //
- // * Specalize call sites based on the actual resources
- // that a called function will return/output.
- //
- // * Specialize called functions based on teh actual resources
- // passed ass input at specific call sites.
- //
- // Because the legalization may depend on what target
- // we are compiling for (certain things might be okay
- // for D3D targets that are not okay for Vulkan), we
- // pass down the target request along with the IR.
- //
- specializeResourceOutputs(compileRequest, targetRequest, irModule);
- //
- // After specialization of function outputs, we may find that there
- // are cases where opaque-typed local variables can now be eliminated
- // and turned into SSA temporaries. Such optimization may enable
- // the following passes to "see" and specialize more cases.
- //
- // TODO: We should consider whether there are cases that will require
- // iterating the passes as given here in order to achieve a fully
- // specialized result. If that is the case, we might consider implementing
- // a single combined pass that makes all of the relevant changes and
- // iterates to convergence.
- //
- constructSSA(irModule);
- //
+ // We clean up the usages of resource values here.
+ specializeResourceUsage(compileRequest, targetRequest, irModule);
specializeFuncsForBufferLoadArgs(compileRequest, targetRequest, irModule);
- specializeResourceParameters(compileRequest, targetRequest, irModule);
+
+ //
+ simplifyIR(irModule);
// For GLSL targets, we also want to specialize calls to functions that
// takes array parameters if possible, to avoid performance issues on
@@ -487,6 +465,7 @@ Result linkAndOptimizeIR(
if (isKhronosTarget(targetRequest))
{
specializeArrayParameters(compileRequest, targetRequest, irModule);
+ simplifyIR(irModule);
}
#if 0
@@ -675,6 +654,17 @@ Result linkAndOptimizeIR(
break;
}
+ // Legalize `ImageSubscript` for GLSL.
+ switch (target)
+ {
+ case CodeGenTarget::GLSL:
+ {
+ legalizeImageSubscriptForGLSL(irModule);
+ }
+ break;
+ default:
+ break;
+ }
switch( target )
{
@@ -712,11 +702,16 @@ Result linkAndOptimizeIR(
// functions, so there might still be invalid code in
// our IR module.
//
- // To clean up the code, we will apply a fairly general
- // dead-code-elimination (DCE) pass that only retains
- // whatever code is "live."
+ // We run IR simplification passes again to clean things up.
//
- eliminateDeadCode(irModule);
+ simplifyIR(irModule);
+
+ if (isKhronosTarget(targetRequest))
+ {
+ // As a fallback, if the above specialization steps failed to remove resource type parameters, we will
+ // inline the functions in question to make sure we can produce valid GLSL.
+ performGLSLResourceReturnFunctionInlining(irModule);
+ }
#if 0
dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE");
#endif
@@ -725,8 +720,7 @@ Result linkAndOptimizeIR(
// Lower all bit_cast operations on complex types into leaf-level
// bit_cast on basic types.
lowerBitCast(targetRequest, irModule);
- eliminateDeadCode(irModule);
-
+ simplifyIR(irModule);
// We include one final step to (optionally) dump the IR and validate
// it after all of the optimization passes are complete. This should
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 285e5100c..9ed5249fe 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -81,7 +81,7 @@ struct DeadCodeEliminationContext
// dive into the task of actually finding all
// the live code in a module.
//
- void processModule()
+ bool processModule()
{
// First of all, we know that the root module instruction
// should be considered as live, because otherwise
@@ -174,11 +174,12 @@ struct DeadCodeEliminationContext
// recursively and eliminate those that are "dead" by
// virtue of not having been found live.
//
- eliminateDeadInstsRec(module->getModuleInst());
+ return eliminateDeadInstsRec(module->getModuleInst());
}
- void eliminateDeadInstsRec(IRInst* inst)
+ bool eliminateDeadInstsRec(IRInst* inst)
{
+ bool changed = false;
// Given the instruction `inst` we need to eliminate
// any dead code at, or under it.
//
@@ -192,6 +193,7 @@ struct DeadCodeEliminationContext
// mark the parent of a live instruction as live).
//
inst->removeAndDeallocate();
+ changed = true;
}
else
{
@@ -208,9 +210,10 @@ struct DeadCodeEliminationContext
for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next )
{
next = child->getNextInst();
- eliminateDeadInstsRec(child);
+ changed |= eliminateDeadInstsRec(child);
}
}
+ return changed;
}
// Now we come to the decision procedure we put off before:
@@ -336,7 +339,7 @@ struct DeadCodeEliminationContext
// is straighforward. We set up the context object
// and then defer to it for the real work.
//
-void eliminateDeadCode(
+bool eliminateDeadCode(
IRModule* module,
IRDeadCodeEliminationOptions const& options)
{
@@ -344,7 +347,7 @@ void eliminateDeadCode(
context.module = module;
context.options = options;
- context.processModule();
+ return context.processModule();
}
}
diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h
index 007905486..aac10a20a 100644
--- a/source/slang/slang-ir-dce.h
+++ b/source/slang/slang-ir-dce.h
@@ -17,8 +17,8 @@ namespace Slang
/// "global" dead code elimination (DCE), such as removing
/// types that are unused, functions that are never called,
/// etc.
- ///
- void eliminateDeadCode(
+ /// Returns true if changed.
+ bool eliminateDeadCode(
IRModule* module,
IRDeadCodeEliminationOptions const& options = IRDeadCodeEliminationOptions());
}
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index bb40ecca9..83b75033a 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -8,8 +8,141 @@
namespace Slang
{
+int getIRVectorElementSize(IRType* type) // Returns the element count of a vector type; any non-vector type is reported as 1.
+{
+ if (type->getOp() != kIROp_VectorType)
+ return 1;
+ return (int)(as<IRIntLit>(as<IRVectorType>(type)->getElementCount())->value.intVal); // NOTE(review): the `as<IRIntLit>` result is dereferenced unchecked - assumes the element count is always an integer literal at this point; confirm.
+}
+IRType* getIRVectorBaseType(IRType* type) // Returns the element type of a vector type; any non-vector type is returned unchanged.
+{
+ if (type->getOp() != kIROp_VectorType)
+ return type;
+ return as<IRVectorType>(type)->getElementType(); // op was checked above, so `type` is expected to be an IRVectorType here
+}
-//
+void legalizeImageSubscriptStoreForGLSL(IRBuilder& builder, IRInst* storeInst) // Replaces a `Store`/`SwizzledStore` whose destination is an `ImageSubscript` with an explicit `ImageStore` (preceded by an `ImageLoad` for swizzled stores).
+{
+ builder.setInsertBefore(storeInst); // all replacement insts are emitted right before the store being replaced
+ auto imageSubscript = as<IRImageSubscript>(storeInst->getOperand(0));
+ assert(imageSubscript); // operand 0 must be an ImageSubscript; this is only checked via assert
+ auto imageElementType = cast<IRPtrTypeBase>(imageSubscript->getDataType())->getValueType();
+ auto coordType = imageSubscript->getCoord()->getDataType();
+ auto coordVectorSize = getIRVectorElementSize(coordType);
+ if (coordVectorSize != 1) // rebuild the coordinate type with `int` elements and the same element count
+ {
+ coordType = builder.getVectorType(
+ builder.getIntType(), builder.getIntValue(builder.getIntType(), coordVectorSize));
+ }
+ else
+ {
+ coordType = builder.getIntType();
+ }
+ auto legalizedCoord = imageSubscript->getCoord();
+ if (coordType != imageSubscript->getCoord()->getDataType()) // only emit a conversion when the coordinate type actually changed
+ {
+ legalizedCoord = builder.emitConstructorInst(coordType, 1, &legalizedCoord);
+ }
+ switch (storeInst->getOp())
+ {
+ case kIROp_Store:
+ {
+ auto newValue = storeInst->getOperand(1); // the value being stored
+ if (getIRVectorElementSize(imageElementType) != 4) // widen to a 4-element vector of the same base type when the image element is not already 4-wide
+ {
+ auto vectorBaseType = getIRVectorBaseType(imageElementType);
+ newValue = builder.emitConstructorInst(
+ builder.getVectorType(
+ vectorBaseType, builder.getIntValue(builder.getIntType(), 4)),
+ 1,
+ &newValue);
+ }
+ auto imageStore = builder.emitImageStore(
+ builder.getVoidType(),
+ imageSubscript->getImage(),
+ legalizedCoord,
+ newValue);
+ storeInst->replaceUsesWith(imageStore);
+ storeInst->removeAndDeallocate();
+ if (!imageSubscript->hasUses()) // drop the subscript once nothing references it any more
+ {
+ imageSubscript->removeAndDeallocate();
+ }
+ }
+ break;
+ case kIROp_SwizzledStore:
+ {
+ auto swizzledStore = cast<IRSwizzledStore>(storeInst);
+ // Here we assume the imageElementType is already lowered into float4/uint4 types from any
+ // user-defined type.
+ assert(imageElementType->getOp() == kIROp_VectorType);
+ auto originalValue = builder.emitImageLoad(imageElementType, imageSubscript->getImage(), legalizedCoord); // load the current element first; presumably so components not named by the swizzle keep their values
+ Array<IRInst*, 4> indices; // NOTE(review): declared with size 4 - assumes a swizzle never names more than 4 elements; confirm
+ for (UInt i = 0; i < swizzledStore->getElementCount(); i++)
+ {
+ indices.add(swizzledStore->getElementIndex(i));
+ }
+ auto newValue = builder.emitSwizzleSet(
+ imageElementType,
+ originalValue,
+ swizzledStore->getSource(),
+ swizzledStore->getElementCount(),
+ indices.getBuffer());
+ if (getIRVectorElementSize(imageElementType) != 4) // NOTE(review): the comment above assumes a 4-wide element type, yet this path handles narrower vectors - confirm which holds, and that a narrower `ImageLoad` result type is valid for the emitted GLSL
+ {
+ auto vectorBaseType = getIRVectorBaseType(imageElementType);
+ newValue = builder.emitConstructorInst(
+ builder.getVectorType(
+ vectorBaseType, builder.getIntValue(builder.getIntType(), 4)),
+ 1,
+ &newValue);
+ }
+ auto imageStore = builder.emitImageStore(
+ builder.getVoidType(), imageSubscript->getImage(), legalizedCoord, newValue);
+ storeInst->replaceUsesWith(imageStore);
+ storeInst->removeAndDeallocate();
+ if (!imageSubscript->hasUses()) // same cleanup as the plain-store case
+ {
+ imageSubscript->removeAndDeallocate();
+ }
+ }
+ break;
+ default: // any other opcode is left untouched
+ break;
+ }
+}
+
+void legalizeImageSubscriptForGLSL(IRModule* module) // Walks the module's functions and legalizes every `Store`/`SwizzledStore` whose destination is an `ImageSubscript`.
+{
+ SharedIRBuilder shared(module);
+ IRBuilder builder(shared);
+ for (auto globalInst : module->getModuleInst()->getChildren()) // only `IRFunc`s that are direct children of the module inst are visited
+ {
+ auto func = as<IRFunc>(globalInst);
+ if (!func)
+ continue;
+ for (auto block : func->getBlocks())
+ {
+ auto inst = block->getFirstInst();
+ IRInst* next;
+ for ( ; inst; inst = next)
+ {
+ next = inst->getNextInst(); // fetched up front because legalization removes `inst`
+ switch (inst->getOp())
+ {
+ case kIROp_Store:
+ case kIROp_SwizzledStore:
+ if (inst->getOperand(0)->getOp() == kIROp_ImageSubscript) // only stores that write through an image subscript are rewritten
+ {
+ legalizeImageSubscriptStoreForGLSL(builder, inst);
+ }
+ }
+ }
+ }
+ }
+}
+
+ //
// Legalization of entry points for GLSL:
//
@@ -1341,6 +1474,16 @@ void legalizeEntryPointParameterForGLSL(
}
}
+ if (stage == Stage::Geometry)
+ {
+ // If the user provided no parameters with an input primitive type qualifier, we
+ // default to `triangle`.
+ if (!func->findDecoration<IRGeometryInputPrimitiveTypeDecoration>())
+ {
+ builder->addDecoration(func, kIROp_TriangleInputPrimitiveTypeDecoration);
+ }
+ }
+
// There *can* be multiple streamout parameters, to an entry point (points if nothing else)
{
IRType* type = pp->getFullType();
diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h
index 0c6bf5196..920715be2 100644
--- a/source/slang/slang-ir-glsl-legalize.h
+++ b/source/slang/slang-ir-glsl-legalize.h
@@ -20,4 +20,6 @@ void legalizeEntryPointsForGLSL(
DiagnosticSink* sink,
GLSLExtensionTracker* glslExtensionTracker);
+void legalizeImageSubscriptForGLSL(IRModule* module);
+
}
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index a8347438b..ed4b34875 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -1,6 +1,8 @@
// slang-ir-inline.cpp
#include "slang-ir-inline.h"
+#include "slang-ir-ssa-simplification.h"
+
// This file provides general facilities for inlining function calls.
//
@@ -33,17 +35,18 @@ struct InliningPassBase
}
/// Consider all the call sites in the module for inliing
- void considerAllCallSites()
+ bool considerAllCallSites()
{
- considerAllCallSitesRec(m_module->getModuleInst());
+ return considerAllCallSitesRec(m_module->getModuleInst());
}
/// Consider all call sites at or under `inst` for inlining
- void considerAllCallSitesRec(IRInst* inst)
+ bool considerAllCallSitesRec(IRInst* inst)
{
+ bool changed = false;
if( auto call = as<IRCall>(inst) )
{
- considerCallSite(call);
+ changed = considerCallSite(call);
}
// Note: we defensively iterate through the child instructions
@@ -54,8 +57,9 @@ struct InliningPassBase
for( auto child = inst->getFirstChild(); child; child = next )
{
next = child->getNextInst();
- considerAllCallSitesRec(child);
+ changed |= considerAllCallSitesRec(child);
}
+ return changed;
}
// In order to inline a call site, we need certain information
@@ -93,7 +97,7 @@ struct InliningPassBase
// basic proces of considering a call site for inlining.
/// Consider the given `call` site, and possibly inline it.
- void considerCallSite(IRCall* call)
+ bool considerCallSite(IRCall* call)
{
// We start by checking if inlining would even be possible,
// since doing so collects information about the call site
@@ -104,7 +108,7 @@ struct InliningPassBase
//
CallSiteInfo callSite;
if(!canInline(call, callSite))
- return;
+ return false;
// If we've decided that we *can* inline the given call
// site, we next need to check if we *should*. The rules
@@ -112,13 +116,14 @@ struct InliningPassBase
// so `shouldInline` is a virtual method.
//
if(!shouldInline(callSite))
- return;
+ return false;
// Finally, if we both *can* and *should* inline the
// given call site, we hand off the a worker routine
// that does the meat of the work.
//
inlineCallSite(callSite);
+ return true;
}
// Every subclas of `InliningPassBase` should provide its own
@@ -313,12 +318,11 @@ struct InliningPassBase
}
// For now, our inlining pass only handles the case where
- // the callee is a "trivial" function, which can support
- // a very simple approach to inlining.
- //
- if( isTrivialFunc(callee) )
+ // the callee is a "single-return" function, which means the callee
+ // function contains only one return at the end of the body.
+ if (isSingleReturnFunc(callee))
{
- inlineTrivialFuncBody(callSite, &env, &builder);
+ inlineSingleReturnFuncBody(callSite, &env, &builder);
}
else
{
@@ -329,26 +333,34 @@ struct InliningPassBase
}
}
- /// Check if `func` represents a trivial single-block callee that can be inlined simply
- bool isTrivialFunc(IRFunc* func)
+ /// Check if `func` represents a simple callee that has only a single `return`.
+ bool isSingleReturnFunc(IRFunc* func)
{
- // The function must have a single bocy block to be trivial.
- //
auto firstBlock = func->getFirstBlock();
- if( firstBlock->getNextBlock() )
- return false;
// If the body block is decorated (for some reason), then the function is non-trivial.
//
if( firstBlock->getFirstDecoration() )
return false;
- // If the body block terminates in something other than a `return` then the function is non-trivial.
- //
- auto terminator = firstBlock->getTerminator();
- if( !as<IRReturn>(terminator) )
- return false;
-
+ // If the body has more than one returns, we cannot inline it now.
+ bool returnFound = false;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_ReturnVal || inst->getOp() == kIROp_ReturnVoid)
+ {
+ // If the return is not at the end of the block, we cannot handle it.
+ if (inst != block->getTerminator())
+ return false;
+ // If there is already a return found, this function cannot be simple.
+ if (returnFound)
+ return false;
+ returnFound = true;
+ }
+ }
+ }
return true;
}
@@ -374,13 +386,8 @@ struct InliningPassBase
// serialization.
//
// For now this punts on this, and just assumes [__unsafeForceInlineEarly] is not in user code.
- static IRInst* _cloneInstWithSourceLoc(CallSiteInfo const& callSite,
- IRCloneEnv* env,
- IRBuilder* builder,
- IRInst* inst)
+ static void _setSourceLoc(IRInst* clonedInst, IRInst* srcInst, CallSiteInfo const& callSite)
{
- IRInst* clonedInst = cloneInst(env, builder, inst);
-
SourceLoc sourceLoc;
if (callSite.call->sourceLoc.isValid())
@@ -388,67 +395,129 @@ struct InliningPassBase
// Default to using the source loc at the call site
sourceLoc = callSite.call->sourceLoc;
}
- else if (inst->sourceLoc.isValid())
+ else if (srcInst->sourceLoc.isValid())
{
// If we don't have that copy the inst being cloned sourceLoc
- sourceLoc = inst->sourceLoc;
+ sourceLoc = srcInst->sourceLoc;
}
clonedInst->sourceLoc = sourceLoc;
+ }
+ static IRInst* _cloneInstWithSourceLoc(CallSiteInfo const& callSite,
+ IRCloneEnv* env,
+ IRBuilder* builder,
+ IRInst* inst)
+ {
+ IRInst* clonedInst = cloneInst(env, builder, inst);
+ _setSourceLoc(clonedInst, inst, callSite);
return clonedInst;
}
- /// Inline the body of the callee for `callSite`, where the callee is trivial as tested by `isTrivialFunc`
- void inlineTrivialFuncBody(CallSiteInfo const& callSite, IRCloneEnv* env, IRBuilder* builder)
+ /// Inline the body of the callee for `callSite`, where the callee has a single return.
+ void inlineSingleReturnFuncBody(
+ CallSiteInfo const& callSite, IRCloneEnv* env, IRBuilder* builder)
{
auto call = callSite.call;
auto callee = callSite.callee;
- auto firstBlock = callee->getFirstBlock();
- // We know that the callee has a single block, so if we encounter
+ // We know that the callee has a single return block, so if we encounter
// a `returnVal` instruction then it must be the one and only
- // return point for the block, and its operand will be the value
- // the calee returns.
+ // return point for the function, and its operand will be the value
+ // the callee returns.
//
IRInst* returnedValue = nullptr;
- // We will loop over the instructions of the one and only block,
- // and clone each of them appropriately.
- //
- for( auto inst : firstBlock->getChildren() )
+ // Break the basic block containing the call inst into two basic blocks.
+ auto callerBlock = callSite.call->getParent();
+ builder->setInsertInto(callerBlock->getParent());
+ auto afterBlock = builder->createBlock();
+
+ // Many operations (e.g. `cloneInst`) have define-before-use assumptions on the IR.
+ // It is important to make sure we keep the ordering of blocks by inserting the
+ // second half of the basic block right after `callerBlock`.
+ afterBlock->insertAfter(callerBlock);
+ afterBlock->sourceLoc = callSite.call->getNextInst()->sourceLoc;
+ // Move all insts after the call in `callerBlock` to `afterBlock`.
{
- switch( inst->getOp() )
+ auto inst = callSite.call->getNextInst();
+ while (inst)
{
- default:
- // The default value is to clone the instruction using
- // the existing cloning infrastructure and the `env`
- // we have already set up.
- //
- // SourceLoc information is copied if there is appropriate data available.
- _cloneInstWithSourceLoc(callSite, env, builder, inst);
- break;
-
- case kIROp_Param:
- // Parameters can be completely ignored in the single-block
- // case, because they have all been replaced via `env`.
- break;
-
- case kIROp_ReturnVoid:
- // A return with no operand can be ignored, since a return
- // from the inlined call should just continue after the
- // call site.
- //
- break;
-
- case kIROp_ReturnVal:
- // A return with a value is similar to `returnVoid` except
- // that we need to note the (clone of the) value being
- // returned, so that we can use it to replace the value
- // of the original call.
- //
- returnedValue = findCloneForOperand(env, inst->getOperand(0));
- break;
+ auto next = inst->getNextInst();
+ inst->removeFromParent();
+ inst->insertAtEnd(afterBlock);
+ inst = next;
+ }
+ }
+
+ List<IRBlock*> clonedBlocks;
+ for (auto calleeBlock : callee->getBlocks())
+ {
+ auto clonedBlock = builder->createBlock();
+ clonedBlock->insertBefore(afterBlock);
+ _setSourceLoc(clonedBlock, calleeBlock, callSite);
+ env->mapOldValToNew[calleeBlock] = clonedBlock;
+ }
+
+ // Insert a branch into the cloned first block at the end of `callerBlock`.
+ builder->setInsertInto(callerBlock);
+ auto newBranch = builder->emitBranch(as<IRBlock>(env->mapOldValToNew[callee->getFirstBlock()].GetValue()));
+ _setSourceLoc(newBranch, call, callSite);
+ // Clone all basic blocks over to the call site.
+ bool isFirstBlock = true;
+ for (auto calleeBlock : callee->getBlocks())
+ {
+ auto clonedBlock = env->mapOldValToNew[calleeBlock].GetValue();
+ builder->setInsertInto(clonedBlock);
+ // We will loop over the instructions of each block,
+ // and clone each of them appropriately.
+ //
+ for (auto inst : calleeBlock->getChildren())
+ {
+ if (inst->getOp() == kIROp_Param)
+ {
+ // Parameters in the first block can be completely ignored
+ // because they have all been replaced via `env`.
+ if (isFirstBlock)
+ {
+ continue;
+ }
+ }
+
+ switch (inst->getOp())
+ {
+ default:
+ // The default value is to clone the instruction using
+ // the existing cloning infrastructure and the `env`
+ // we have already set up.
+ //
+ // SourceLoc information is copied if there is appropriate data available.
+ _cloneInstWithSourceLoc(callSite, env, builder, inst);
+ break;
+
+ case kIROp_ReturnVoid:
+ // A return with no operand is replaced with a branch into `afterBlock`
+ // to return the control flow to the location after the original `call`.
+ {
+ auto returnBranch = builder->emitBranch(afterBlock);
+ _setSourceLoc(returnBranch, inst, callSite);
+ }
+ break;
+
+ case kIROp_ReturnVal:
+ // A return with a value is similar to `returnVoid` except
+ // that we need to note the (clone of the) value being
+ // returned, so that we can use it to replace the value
+ // of the original call.
+ //
+ {
+ auto returnBranch = builder->emitBranch(afterBlock);
+ _setSourceLoc(returnBranch, inst, callSite);
+ returnedValue = findCloneForOperand(env, inst->getOperand(0));
+ }
+ break;
+ }
}
+ isFirstBlock = false;
}
// If there was a `returnVal` instruction that established
@@ -492,4 +561,73 @@ void performMandatoryEarlyInlining(IRModule* module)
pass.considerAllCallSites();
}
+
+ // Defined in slang-ir-specialize-resources.cpp
+bool isResourceType(IRType* type);
+bool isIllegalGLSLParameterType(IRType* type);
+
+/// An inlining pass that inlines calls to functions that hand a resource back to the
+/// caller (through the result or an `out` parameter), or that take a parameter whose
+/// type `isIllegalGLSLParameterType` rejects. This is needed for GLSL targets.
+struct GLSLResourceReturnFunctionInliningPass : InliningPassBase
+{
+    typedef InliningPassBase Super;
+
+    GLSLResourceReturnFunctionInliningPass(IRModule* module)
+        : Super(module)
+    {}
+
+    /// Decide whether the call site described by `info` must be inlined.
+    bool shouldInline(CallSiteInfo const& info)
+    {
+        auto callee = info.callee;
+
+        // A resource-typed result is enough on its own.
+        if (isResourceType(callee->getResultType()))
+            return true;
+
+        for (auto param : callee->getParams())
+        {
+            auto paramType = param->getDataType();
+
+            // The parameter type itself is not allowed.
+            if (isIllegalGLSLParameterType(paramType))
+                return true;
+
+            // An `out`-style parameter whose value type is a resource.
+            if (auto outType = as<IROutTypeBase>(paramType))
+            {
+                if (isResourceType(outType->getValueType()))
+                    return true;
+            }
+        }
+        return false;
+    }
+};
+
+void performGLSLResourceReturnFunctionInlining(IRModule* module)
+{
+    GLSLResourceReturnFunctionInliningPass pass(module);
+
+    // Each round considers every call site and then simplifies the module.
+    // Rounds repeat until `considerAllCallSites` returns false; the
+    // simplification still runs after that final round.
+    bool inlinedAny = false;
+    do
+    {
+        inlinedAny = pass.considerAllCallSites();
+        simplifyIR(module);
+    } while (inlinedAny);
+}
+
+/// An inlining pass whose `shouldInline` accepts every call site it is asked about.
+struct CustomInliningPass : InliningPassBase
+{
+    typedef InliningPassBase Super;
+
+    CustomInliningPass(IRModule* module)
+        : Super(module)
+    {}
+
+    /// Always answers yes; the call site itself is not inspected.
+    bool shouldInline(CallSiteInfo const& /*callSite*/)
+    {
+        return true;
+    }
+};
+
+bool inlineCall(IRCall* call)
+{
+    // `CustomInliningPass::shouldInline` never refuses, so the outcome is
+    // whatever `considerCallSite` reports for this one call.
+    auto module = call->getModule();
+    CustomInliningPass inliner(module);
+    const bool wasInlined = inliner.considerCallSite(call);
+    return wasInlined;
+}
+
+
} // namespace Slang
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index 322b4c855..8ac23f6b0 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -4,7 +4,14 @@
namespace Slang
{
struct IRModule;
+ struct IRCall;
/// Inline any call sites to functions marked `[unsafeForceInlineEarly]`
void performMandatoryEarlyInlining(IRModule* module);
+
+ /// Inline calls to functions that return a resource/sampler via either return value or output parameter.
+ void performGLSLResourceReturnFunctionInlining(IRModule* module);
+
+ /// Inline a specific call.
+ bool inlineCall(IRCall* call);
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 0fca118d1..a3486ee68 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -296,6 +296,11 @@ INST(getAddr, getAddr, 1, 0)
// "Subscript" an image at a pixel coordinate to get pointer
INST(ImageSubscript, imageSubscript, 2, 0)
+// Load from an Image.
+INST(ImageLoad, imageLoad, 2, 0)
+// Store into an Image.
+INST(ImageStore, imageStore, 3, 0)
+
// Load (almost) arbitrary-type data from a byte-address buffer
//
// %dst = byteAddressBufferLoad(%buffer, %offset)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9b50047de..09a3bdbb3 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1368,6 +1368,27 @@ struct IRGetAddress : IRInst
IR_LEAF_ISA(getAddr);
};
+struct IRImageSubscript : IRInst
+{
+ IR_LEAF_ISA(ImageSubscript);
+ IRInst* getImage() { return getOperand(0); }
+ IRInst* getCoord() { return getOperand(1); }
+};
+
+struct IRImageLoad : IRInst
+{
+ IR_LEAF_ISA(ImageLoad);
+ IRInst* getImage() { return getOperand(0); }
+ IRInst* getCoord() { return getOperand(1); }
+};
+
+struct IRImageStore : IRInst
+{
+ IR_LEAF_ISA(ImageStore);
+ IRInst* getImage() { return getOperand(0); }
+ IRInst* getCoord() { return getOperand(1); }
+ IRInst* getValue() { return getOperand(2); }
+};
// Terminators
struct IRReturn : IRTerminatorInst
@@ -2413,6 +2434,17 @@ public:
IRInst* dstPtr,
IRInst* srcVal);
+ IRInst* emitImageLoad(
+ IRType* type,
+ IRInst* image,
+ IRInst* coord);
+
+ IRInst* emitImageStore(
+ IRType* type,
+ IRInst* image,
+ IRInst* coord,
+ IRInst* value);
+
IRInst* emitFieldExtract(
IRType* type,
IRInst* base,
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index 159cf9abc..463de29ca 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -109,6 +109,41 @@ struct SCCPContext
}
};
+    /// Returns true when `op` is an opcode that `applyOnGlobalScope` may try to fold:
+    /// the literal opcodes plus the scalar arithmetic, comparison, logical, bitwise,
+    /// cast, construct and select operations.
+    ///
+    /// The list is kept in step with the opcodes the lattice evaluator has a case for;
+    /// `kIROp_And` and `kIROp_Or` are included because the evaluator folds them too.
+    static bool isEvaluableOpCode(IROp op)
+    {
+        switch (op)
+        {
+        // Literals.
+        case kIROp_IntLit:
+        case kIROp_BoolLit:
+        case kIROp_FloatLit:
+        case kIROp_StringLit:
+        // Arithmetic.
+        case kIROp_Add:
+        case kIROp_Sub:
+        case kIROp_Mul:
+        case kIROp_Div:
+        case kIROp_Neg:
+        // Logical.
+        case kIROp_Not:
+        case kIROp_And:
+        case kIROp_Or:
+        // Comparison.
+        case kIROp_Eql:
+        case kIROp_Neq:
+        case kIROp_Leq:
+        case kIROp_Geq:
+        case kIROp_Less:
+        case kIROp_Greater:
+        // Bitwise.
+        case kIROp_Lsh:
+        case kIROp_Rsh:
+        case kIROp_BitAnd:
+        case kIROp_BitOr:
+        case kIROp_BitXor:
+        case kIROp_BitNot:
+        // Conversion and selection.
+        case kIROp_BitCast:
+        case kIROp_Construct:
+        case kIROp_Select:
+            return true;
+        default:
+            return false;
+        }
+    }
+
// If we imagine a variable (actually an SSA phi node...) that
// might be assigned lattice value A at one point in the code,
// and lattice value B at another point, we need a way to
@@ -204,20 +239,29 @@ struct SCCPContext
break;
}
- // We might be asked for the lattice value of an instruction
- // not contained in the current function. When that happens,
- // we will treat it as having potentially any value, rather
- // than the default of none.
- //
- auto parentBlock = as<IRBlock>(inst->getParent());
- if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny();
-
- // Once the special cases are dealt with, we can look up in
- // the dictionary and just return the value we get from it,
- // or default to the `None` (empty set) case.
+ // Look up in the dictionary and just return the value we get from it.
LatticeVal latticeVal;
if(mapInstToLatticeVal.TryGetValue(inst, latticeVal))
return latticeVal;
+
+ // If we can't find the value from dictionary, we want to return None if this is a value
+ // in the same function as the one we are working with right now. If it is defined
+ // elsewhere, we return Any.
+ auto parentBlock = as<IRBlock>(inst->getParent());
+ bool isProcessingGlobalScope = (code == nullptr);
+ if (!parentBlock && isProcessingGlobalScope)
+ {
+ // We are folding constant in the global scope, continue registering the inst as Any.
+ }
+ else
+ {
+ // If we are processing a function and asked for the lattice value of an instruction
+ // not contained in the current function, we will treat it as having potentially any
+ // value, rather than the default of none.
+ //
+ if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny();
+ }
+
return LatticeVal::getNone();
}
@@ -228,6 +272,460 @@ struct SCCPContext
IRBuilder builderStorage;
IRBuilder* getBuilder() { return &builderStorage; }
+ // LatticeVal constant evaluation methods.
+#define SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v) \
+ switch (v.flavor) \
+ { \
+ case LatticeVal::Flavor::None: \
+ return LatticeVal::getNone(); \
+ case LatticeVal::Flavor::Any: \
+ return LatticeVal::getAny(); \
+ default: \
+ break; \
+ }
+
+    /// Fold a `construct` of the single operand value `v0` into a constant of scalar `type`.
+    ///
+    /// `None` and `Any` operands are passed through unchanged. Integer, boolean and
+    /// floating-point literals are converted to integer, floating-point or boolean
+    /// targets; every other operand or target yields `Any`.
+    LatticeVal evalConstruct(IRType* type, LatticeVal v0)
+    {
+        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+        auto irConstant = as<IRConstant>(v0.value);
+        if (!irConstant)
+            return LatticeVal::getAny();
+
+        IRInst* resultVal = nullptr;
+        switch (type->getOp())
+        {
+        case kIROp_Int8Type:
+        case kIROp_Int16Type:
+        case kIROp_IntType:
+        case kIROp_Int64Type:
+        case kIROp_UInt8Type:
+        case kIROp_UInt16Type:
+        case kIROp_UIntType:
+        case kIROp_UInt64Type:
+            {
+                IRIntegerValue intVal = 0;
+                switch (irConstant->getOp())
+                {
+                case kIROp_FloatLit:
+                    intVal = (IRIntegerValue)irConstant->value.floatVal;
+                    break;
+                case kIROp_IntLit:
+                case kIROp_BoolLit:
+                    intVal = irConstant->value.intVal;
+                    break;
+                default:
+                    return LatticeVal::getAny();
+                }
+
+                // Narrow to the width of the target type. Signed targets must be
+                // sign-extended back into the wide value (a plain bit mask would turn
+                // an 8-bit -1 into 255), unsigned targets are zero-extended.
+                switch (type->getOp())
+                {
+                case kIROp_Int8Type:
+                    intVal = (IRIntegerValue)(int8_t)intVal;
+                    break;
+                case kIROp_UInt8Type:
+                    intVal = (IRIntegerValue)(uint8_t)intVal;
+                    break;
+                case kIROp_Int16Type:
+                    intVal = (IRIntegerValue)(int16_t)intVal;
+                    break;
+                case kIROp_UInt16Type:
+                    intVal = (IRIntegerValue)(uint16_t)intVal;
+                    break;
+                case kIROp_IntType:
+                    intVal = (IRIntegerValue)(int32_t)intVal;
+                    break;
+                case kIROp_UIntType:
+                    intVal = (IRIntegerValue)(uint32_t)intVal;
+                    break;
+                default:
+                    // 64-bit targets keep the value as is.
+                    break;
+                }
+                resultVal = getBuilder()->getIntValue(type, intVal);
+            }
+            break;
+        case kIROp_FloatType:
+        case kIROp_DoubleType:
+        case kIROp_HalfType:
+            switch (irConstant->getOp())
+            {
+            case kIROp_FloatLit:
+                resultVal = getBuilder()->getFloatValue(
+                    type, (IRFloatingPointValue)irConstant->value.floatVal);
+                break;
+            case kIROp_IntLit:
+            case kIROp_BoolLit:
+                resultVal = getBuilder()->getFloatValue(
+                    type, (IRFloatingPointValue)irConstant->value.intVal);
+                break;
+            default:
+                return LatticeVal::getAny();
+            }
+            break;
+        case kIROp_BoolType:
+            switch (irConstant->getOp())
+            {
+            case kIROp_FloatLit:
+                resultVal = getBuilder()->getBoolValue(irConstant->value.floatVal != 0);
+                break;
+            case kIROp_IntLit:
+            case kIROp_BoolLit:
+                resultVal = getBuilder()->getBoolValue(irConstant->value.intVal != 0);
+                break;
+            default:
+                return LatticeVal::getAny();
+            }
+            break;
+        default:
+            // Not a scalar type this evaluator knows how to construct.
+            break;
+        }
+        if (!resultVal)
+            return LatticeVal::getAny();
+        return LatticeVal::getConstant(resultVal);
+    }
+
+ template<typename TIntFunc, typename TFloatFunc>
+ LatticeVal evalBinaryImpl(
+ IRType* type,
+ LatticeVal v0,
+ LatticeVal v1,
+ const TIntFunc& intFunc,
+ const TFloatFunc& floatFunc)
+ {
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+ auto c0 = as<IRConstant>(v0.value);
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
+ auto c1 = as<IRConstant>(v1.value);
+ IRInst* resultVal = nullptr;
+ switch (type->getOp())
+ {
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_BoolType:
+ resultVal = getBuilder()->getIntValue(type, intFunc(c0->value.intVal, c1->value.intVal));
+ break;
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ case kIROp_HalfType:
+ resultVal = getBuilder()->getFloatValue(type, floatFunc(c0->value.floatVal, c1->value.floatVal));
+ break;
+ default:
+ break;
+ }
+ if (!resultVal)
+ return LatticeVal::getAny();
+ return LatticeVal::getConstant(resultVal);
+ }
+
+ template <typename TIntFunc>
+ LatticeVal evalBinaryIntImpl(
+ IRType* type,
+ LatticeVal v0,
+ LatticeVal v1,
+ const TIntFunc& intFunc)
+ {
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+ auto c0 = as<IRConstant>(v0.value);
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
+ auto c1 = as<IRConstant>(v1.value);
+ IRInst* resultVal = nullptr;
+ switch (type->getOp())
+ {
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_BoolType:
+ resultVal =
+ getBuilder()->getIntValue(type, intFunc(c0->value.intVal, c1->value.intVal));
+ break;
+ default:
+ break;
+ }
+ if (!resultVal)
+ return LatticeVal::getAny();
+ return LatticeVal::getConstant(resultVal);
+ }
+
+ template <typename TIntFunc>
+ LatticeVal evalUnaryIntImpl(
+ IRType* type, LatticeVal v0, const TIntFunc& intFunc)
+ {
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+ auto c0 = as<IRConstant>(v0.value);
+ IRInst* resultVal = nullptr;
+ switch (type->getOp())
+ {
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_BoolType:
+ resultVal =
+ getBuilder()->getIntValue(type, intFunc(c0->value.intVal));
+ break;
+ default:
+ break;
+ }
+ if (!resultVal)
+ return LatticeVal::getAny();
+ return LatticeVal::getConstant(resultVal);
+ }
+
+ template <typename TIntFunc, typename TFloatFunc>
+ LatticeVal evalComparisonImpl(
+ IRType* type,
+ LatticeVal v0,
+ LatticeVal v1,
+ const TIntFunc& intFunc,
+ const TFloatFunc& floatFunc)
+ {
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+ auto c0 = as<IRConstant>(v0.value);
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
+ auto c1 = as<IRConstant>(v1.value);
+ IRInst* resultVal = nullptr;
+ switch (type->getOp())
+ {
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_BoolType:
+ resultVal =
+ getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal));
+ break;
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ case kIROp_HalfType:
+ resultVal =
+ getBuilder()->getBoolValue(floatFunc(c0->value.floatVal, c1->value.floatVal));
+ break;
+ default:
+ break;
+ }
+ if (!resultVal)
+ return LatticeVal::getAny();
+ return LatticeVal::getConstant(resultVal);
+ }
+
+ LatticeVal evalAdd(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 + c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 + c1; });
+ }
+ LatticeVal evalSub(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 - c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 - c1; });
+ }
+ LatticeVal evalMul(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 * c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 * c1; });
+ }
+    /// Fold `v0 / v1` for scalar `type`.
+    ///
+    /// Behaves like the other binary folds, except that an integer division whose
+    /// divisor is the constant zero is never evaluated: the result is `Any`, so the
+    /// instruction is left in the IR instead of being folded.
+    LatticeVal evalDiv(IRType* type, LatticeVal v0, LatticeVal v1)
+    {
+        // Dividing by zero in the compiler's own integer arithmetic is undefined
+        // behavior, so the integer callback records the situation and returns a
+        // placeholder; the placeholder result is discarded below.
+        bool dividedByZero = false;
+        auto result = evalBinaryImpl(
+            type,
+            v0,
+            v1,
+            [&dividedByZero](IRIntegerValue c0, IRIntegerValue c1) -> IRIntegerValue
+            {
+                if (c1 == 0)
+                {
+                    dividedByZero = true;
+                    return 0;
+                }
+                return c0 / c1;
+            },
+            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 / c1; });
+        if (dividedByZero)
+            return LatticeVal::getAny();
+        return result;
+    }
+ LatticeVal evalEql(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalComparisonImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 == c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 == c1; });
+ }
+ LatticeVal evalNeq(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalComparisonImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 != c1; });
+ }
+ LatticeVal evalGeq(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalComparisonImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 >= c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 >= c1; });
+ }
+ LatticeVal evalLeq(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalComparisonImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 <= c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 <= c1; });
+ }
+ LatticeVal evalGreater(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalComparisonImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 > c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 > c1; });
+ }
+ LatticeVal evalLess(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalComparisonImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 < c1; },
+ [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 < c1; });
+ }
+ LatticeVal evalAnd(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryIntImpl(
+ type,
+ v0,
+ v1,
+ [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != 0 && c1 != 0; });
+ }
+ LatticeVal evalOr(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryIntImpl(
+ type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != 0 || c1 != 0; });
+ }
+ LatticeVal evalNot(IRType* type, LatticeVal v0)
+ {
+ return evalUnaryIntImpl(type, v0, [](IRIntegerValue c0) { return c0 == 0; });
+ }
+ LatticeVal evalBitAnd(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryIntImpl(
+ type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 & c1; });
+ }
+ LatticeVal evalBitOr(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryIntImpl(
+ type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 | c1; });
+ }
+ LatticeVal evalBitNot(IRType* type, LatticeVal v0)
+ {
+ return evalUnaryIntImpl(type, v0, [](IRIntegerValue c0) { return ~c0; });
+ }
+ LatticeVal evalBitXor(IRType* type, LatticeVal v0, LatticeVal v1)
+ {
+ return evalBinaryIntImpl(
+ type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 ^ c1; });
+ }
+    /// Fold `v0 << v1` for integer `type`.
+    ///
+    /// A shift count that is negative or not smaller than the bit width of
+    /// `IRIntegerValue` is not evaluated; the result is `Any` in that case.
+    LatticeVal evalLsh(IRType* type, LatticeVal v0, LatticeVal v1)
+    {
+        // Such a shift is undefined behavior in the compiler's own arithmetic, so the
+        // callback flags it and returns a placeholder that is discarded below.
+        bool badShift = false;
+        auto result = evalBinaryIntImpl(
+            type,
+            v0,
+            v1,
+            [&badShift](IRIntegerValue c0, IRIntegerValue c1) -> IRIntegerValue
+            {
+                if (c1 < 0 || c1 >= (IRIntegerValue)(sizeof(IRIntegerValue) * 8))
+                {
+                    badShift = true;
+                    return 0;
+                }
+                return c0 << c1;
+            });
+        if (badShift)
+            return LatticeVal::getAny();
+        return result;
+    }
+    /// Fold `v0 >> v1` for integer `type`.
+    ///
+    /// A shift count that is negative or not smaller than the bit width of
+    /// `IRIntegerValue` is not evaluated; the result is `Any` in that case.
+    LatticeVal evalRsh(IRType* type, LatticeVal v0, LatticeVal v1)
+    {
+        // Such a shift is undefined behavior in the compiler's own arithmetic, so the
+        // callback flags it and returns a placeholder that is discarded below.
+        bool badShift = false;
+        auto result = evalBinaryIntImpl(
+            type,
+            v0,
+            v1,
+            [&badShift](IRIntegerValue c0, IRIntegerValue c1) -> IRIntegerValue
+            {
+                if (c1 < 0 || c1 >= (IRIntegerValue)(sizeof(IRIntegerValue) * 8))
+                {
+                    badShift = true;
+                    return 0;
+                }
+                return c0 >> c1;
+            });
+        if (badShift)
+            return LatticeVal::getAny();
+        return result;
+    }
+ LatticeVal evalNeg(IRType* type, LatticeVal v0)
+ {
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+ auto c0 = as<IRConstant>(v0.value);
+ IRInst* resultVal = nullptr;
+ switch (type->getOp())
+ {
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ resultVal = getBuilder()->getIntValue(type, -c0->value.intVal);
+ break;
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ case kIROp_HalfType:
+ resultVal = getBuilder()->getFloatValue(type, -c0->value.floatVal);
+ break;
+ default:
+ break;
+ }
+ if (!resultVal)
+ return LatticeVal::getAny();
+ return LatticeVal::getConstant(resultVal);
+ }
+
+    /// Fold a `bitCast` of `v0` to scalar `type` by reinterpreting the literal's bits.
+    ///
+    /// `None` and `Any` operands are passed through unchanged. The 32-bit cases only
+    /// fold when the operand literal is of the kind the reinterpretation expects
+    /// (a float literal for an integer target, an integer literal for a `float`
+    /// target); every other combination yields `Any` rather than a value derived
+    /// from the wrong member of the literal.
+    LatticeVal evalBitCast(IRType* type, LatticeVal v0)
+    {
+        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+        auto c0 = as<IRConstant>(v0.value);
+        if (!c0)
+            return LatticeVal::getAny();
+
+        IRInst* resultVal = nullptr;
+        switch (type->getOp())
+        {
+        case kIROp_Int64Type:
+        case kIROp_UInt64Type:
+            // NOTE(review): this reads `intVal` whatever the operand's literal kind is;
+            // confirm that is the intended way to obtain the bits of a double operand.
+            resultVal = getBuilder()->getIntValue(type, c0->value.intVal);
+            break;
+        case kIROp_IntType:
+        case kIROp_UIntType:
+            // Reinterprets the bits of a 32-bit float, so the operand must be a
+            // float literal.
+            if (c0->getOp() == kIROp_FloatLit)
+            {
+                float val = (float)c0->value.floatVal;
+                uint32_t intVal = (uint32_t)FloatAsInt(val);
+                resultVal = getBuilder()->getIntValue(type, intVal);
+            }
+            break;
+        case kIROp_FloatType:
+            // Reinterprets the low 32 bits of an integer, so the operand must be an
+            // integer literal.
+            if (c0->getOp() == kIROp_IntLit)
+            {
+                uint32_t val = (uint32_t)c0->value.intVal;
+                float floatVal = IntAsFloat((int)val);
+                resultVal = getBuilder()->getFloatValue(type, floatVal);
+            }
+            break;
+        case kIROp_DoubleType:
+            // NOTE(review): same assumption as the 64-bit integer case above.
+            resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(c0->value.intVal));
+            break;
+        default:
+            break;
+        }
+        if (!resultVal)
+            return LatticeVal::getAny();
+        return LatticeVal::getConstant(resultVal);
+    }
+
+ LatticeVal evalSelect(LatticeVal v0, LatticeVal v1, LatticeVal v2)
+ {
+ SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
+ auto c0 = as<IRConstant>(v0.value);
+ return c0->value.intVal != 0 ? v1 : v2;
+ }
+
// In order to perform constant folding, we need to be able to
// interpret an instruction over the lattice values.
//
@@ -245,11 +743,17 @@ struct SCCPContext
case kIROp_BoolLit:
return LatticeVal::getConstant(inst);
- // TODO: we might also want to special-case certain
+ // We might also want to special-case certain
// instructions where we shouldn't bother trying to
// constant-fold them and should just default to the
// `Any` value right away.
-
+ case kIROp_Call:
+ case kIROp_ByteAddressBufferLoad:
+ case kIROp_ByteAddressBufferStore:
+ case kIROp_Alloca:
+ case kIROp_Store:
+ case kIROp_Load:
+ return LatticeVal::getAny();
default:
break;
}
@@ -288,10 +792,116 @@ struct SCCPContext
// `None` inputs as producing `Any` to make sure we don't
// optimize the code based on non-obvious assumptions.
//
- // For now we aren't implementing *any* folding logic here,
- // for simplicity. This is the right place to add folding
- // optimizations if/when we need them.
- //
+ // For now we implement only basic folding operations for
+ // scalar values.
+ if (!as<IRBasicType>(inst->getDataType()))
+ return LatticeVal::getAny();
+
+ switch (inst->getOp())
+ {
+ case kIROp_Construct:
+ return evalConstruct(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
+ case kIROp_Add:
+ return evalAdd(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Sub:
+ return evalSub(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Mul:
+ return evalMul(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Div:
+ return evalDiv(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Eql:
+ return evalEql(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Neq:
+ return evalNeq(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Greater:
+ return evalGreater(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Less:
+ return evalLess(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Leq:
+ return evalLeq(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Geq:
+ return evalGeq(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_And:
+ return evalAnd(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Or:
+ return evalOr(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Not:
+ return evalNot(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
+ case kIROp_BitAnd:
+ return evalBitAnd(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_BitOr:
+ return evalBitOr(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_BitNot:
+ return evalBitNot(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
+ case kIROp_BitXor:
+ return evalBitXor(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_BitCast:
+ return evalBitCast(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
+ case kIROp_Neg:
+ return evalNeg(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
+ case kIROp_Lsh:
+ return evalLsh(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Rsh:
+ return evalRsh(
+ inst->getDataType(),
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)));
+ case kIROp_Select:
+ return evalSelect(
+ getLatticeVal(inst->getOperand(0)),
+ getLatticeVal(inst->getOperand(1)),
+ getLatticeVal(inst->getOperand(2)));
+ default:
+ break;
+ }
// A safe default is to assume that every instruction not
// handled by one of the cases above could produce *any*
@@ -567,10 +1177,60 @@ struct SCCPContext
}
}
+    // Run constant folding over the instructions that are direct children of the
+    // module instruction, and nothing else.
+    //
+    // Returns true when at least one instruction was replaced by its constant value
+    // and removed.
+    bool applyOnGlobalScope(IRModule* module)
+    {
+        builderStorage.init(shared->sharedBuilder);
+        auto moduleInst = module->getModuleInst();
+
+        // Seed: evaluate every evaluable global instruction once.
+        for (auto globalInst : moduleInst->getChildren())
+        {
+            if (isEvaluableOpCode(globalInst->getOp()))
+                updateValueForInst(globalInst);
+        }
+
+        // Drain the work list, re-evaluating only evaluable instructions that live
+        // directly under the module; every other entry is dropped.
+        while (ssaWorkList.getCount())
+        {
+            auto pending = ssaWorkList[0];
+            ssaWorkList.fastRemoveAt(0);
+
+            const bool isGlobal = pending->getParent() == moduleInst;
+            if (isGlobal && isEvaluableOpCode(pending->getOp()))
+                updateValueForInst(pending);
+        }
+
+        // Redirect the uses of each instruction that evaluated to a constant other
+        // than itself. Removal is deferred so the child list is not edited while it
+        // is being walked.
+        List<IRInst*> foldedInsts;
+        for (auto globalInst : moduleInst->getChildren())
+        {
+            if (!isEvaluableOpCode(globalInst->getOp()))
+                continue;
+
+            auto folded = getLatticeVal(globalInst);
+            if (folded.flavor != LatticeVal::Flavor::Constant)
+                continue;
+            if (folded.value == globalInst)
+                continue;
+
+            globalInst->replaceUsesWith(folded.value);
+            foldedInsts.add(globalInst);
+        }
+
+        if (foldedInsts.getCount() == 0)
+            return false;
+
+        for (auto inst : foldedInsts)
+            inst->removeAndDeallocate();
+        // Rebuild global value map.
+        builderStorage.getSharedBuilder()->deduplicateAndRebuildGlobalNumberingMap();
+        return true;
+    }
+
// The `apply()` function will run the full algorithm.
//
- void apply()
+ bool apply()
{
+ bool changed = false;
// We start with the busy-work of setting up our IR builder.
//
builderStorage.init(shared->sharedBuilder);
@@ -733,6 +1393,9 @@ struct SCCPContext
}
}
+ if (instsToRemove.getCount() != 0)
+ changed = true;
+
// Once we've replaced the uses of instructions that evaluate
// to constants, we make a second pass to remove the instructions
// themselves (or at least those without side effects).
@@ -786,6 +1449,7 @@ struct SCCPContext
builder->setInsertBefore(terminator);
builder->emitBranch(target);
terminator->removeAndDeallocate();
+ changed = true;
}
}
else if(auto condBranchInst = as<IRConditionalBranch>(terminator))
@@ -800,6 +1464,7 @@ struct SCCPContext
builder->setInsertBefore(terminator);
builder->emitBranch(target);
terminator->removeAndDeallocate();
+ changed = true;
}
}
@@ -911,38 +1576,52 @@ struct SCCPContext
builder->emitUnreachable();
}
}
+ return changed;
}
};
-static void applySparseConditionalConstantPropagationRec(
- SharedSCCPContext* shared,
+static bool applySparseConditionalConstantPropagationRec(
+ const SCCPContext& globalContext,
IRInst* inst)
{
+ bool changed = false;
if( auto code = as<IRGlobalValueWithCode>(inst) )
{
if( code->getFirstBlock() )
{
SCCPContext context;
- context.shared = shared;
+ context.shared = globalContext.shared;
context.code = code;
- context.apply();
+ context.mapInstToLatticeVal = globalContext.mapInstToLatticeVal;
+ changed |= context.apply();
}
}
for( auto childInst : inst->getDecorationsAndChildren() )
{
- applySparseConditionalConstantPropagationRec(shared, childInst);
+ changed |= applySparseConditionalConstantPropagationRec(globalContext, childInst);
}
+ return changed;
}
-void applySparseConditionalConstantPropagation(
+bool applySparseConditionalConstantPropagation(
IRModule* module)
{
SharedSCCPContext shared;
shared.module = module;
shared.sharedBuilder.init(module);
+ shared.sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ // First we fold constants at global scope.
+ SCCPContext globalContext;
+ globalContext.shared = &shared;
+ globalContext.code = nullptr;
+ bool changed = globalContext.applyOnGlobalScope(module);
+
+ // Now run recursive SCCP passes on each child code block.
+ changed |= applySparseConditionalConstantPropagationRec(globalContext, module->getModuleInst());
- applySparseConditionalConstantPropagationRec(&shared, module->getModuleInst());
+ return changed;
}
}
diff --git a/source/slang/slang-ir-sccp.h b/source/slang/slang-ir-sccp.h
index b557eefe3..06c5769c8 100644
--- a/source/slang/slang-ir-sccp.h
+++ b/source/slang/slang-ir-sccp.h
@@ -12,7 +12,8 @@ namespace Slang
/// also eliminates conditional branches where the condition will
/// always evaluate to a constant (which can lead to entire blocks
/// becoming dead code)
- void applySparseConditionalConstantPropagation(
+ /// Returns true if IR is changed.
+ bool applySparseConditionalConstantPropagation(
IRModule* module);
}
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
new file mode 100644
index 000000000..af0e7c0ce
--- /dev/null
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -0,0 +1,82 @@
+#include "slang-ir-simplify-cfg.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+// Merges linear chains of basic blocks in `func`.
+//
+// Blocks are visited starting from the first block and following successor
+// edges. Whenever a visited block ends in an unconditional branch to a
+// successor that has no other use, the successor's params are replaced by the
+// branch arguments, its remaining insts are appended to the block, and the
+// successor is deleted.
+//
+// Returns true if at least one pair of blocks was merged.
+bool processFunc(IRFunc* func)
+{
+    auto firstBlock = func->getFirstBlock();
+    if (!firstBlock)
+        return false;
+
+    bool changed = false;
+
+    List<IRBlock*> workList;
+    HashSet<IRBlock*> processedBlock;
+    // Mark the entry block as processed up front so that a branch back to it
+    // cannot put it on the work list a second time.
+    processedBlock.Add(firstBlock);
+    workList.add(firstBlock);
+    while (workList.getCount())
+    {
+        auto block = workList.getFirst();
+        workList.fastRemoveAt(0);
+        // Keep merging into `block` for as long as it has a mergeable successor.
+        while (block)
+        {
+            // If `block` does not end with an unconditional branch, bail.
+            if (block->getTerminator()->getOp() != kIROp_unconditionalBranch)
+                break;
+            auto branch = as<IRUnconditionalBranch>(block->getTerminator());
+            auto successor = branch->getTargetBlock();
+            // A block that branches to itself must not be merged into itself:
+            // doing so would move its insts while iterating over them and then
+            // deallocate the very block being processed.
+            if (successor == block)
+                break;
+            // Only perform the merge if `block` is the only predecessor of `successor`.
+            // We also need to make sure not to merge a block that serves as the
+            // merge point in CFG. Such blocks will have more than one use.
+            if (successor->hasMoreThanOneUse())
+                break;
+            changed = true;
+            Index paramIndex = 0;
+            auto inst = successor->getFirstDecorationOrChild();
+            while (inst)
+            {
+                // Fetch the next inst first, because `inst` may be moved below.
+                auto next = inst->getNextInst();
+                if (inst->getOp() == kIROp_Param)
+                {
+                    // The n-th param of `successor` receives the n-th branch argument.
+                    inst->replaceUsesWith(branch->getArg(paramIndex));
+                    paramIndex++;
+                }
+                else
+                {
+                    inst->removeFromParent();
+                    inst->insertAtEnd(block);
+                }
+                inst = next;
+            }
+            // The branch was the only use of `successor`; once it is gone the
+            // now-empty successor can be deleted.
+            branch->removeAndDeallocate();
+            assert(!successor->hasUses());
+            successor->removeAndDeallocate();
+        }
+        for (auto successor : block->getSuccessors())
+        {
+            if (processedBlock.Add(successor))
+            {
+                workList.add(successor);
+            }
+        }
+    }
+    return changed;
+}
+
+// Runs `processFunc` on every `IRFunc` among the module's global insts.
+// Returns true if any function was changed.
+bool simplifyCFG(IRModule* module)
+{
+    bool anyFuncChanged = false;
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        auto func = as<IRFunc>(globalInst);
+        if (!func)
+            continue;
+
+        if (processFunc(func))
+            anyFuncChanged = true;
+    }
+    return anyFuncChanged;
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h
new file mode 100644
index 000000000..3d8729274
--- /dev/null
+++ b/source/slang/slang-ir-simplify-cfg.h
@@ -0,0 +1,12 @@
+// slang-ir-simplify-cfg.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ /// Simplifies the control flow graph by merging basic blocks that
+ /// form a simple linear chain.
+ /// Returns true if changed.
+ bool simplifyCFG(IRModule* module);
+}
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index eac5287e5..bcd7b494f 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -123,14 +123,14 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
// the witness table sequential ID passed in.
builder->setInsertInto(newDispatchFunc);
-
+
if (witnessTables.getCount() == 1)
{
// If there is only 1 case, no switch statement is necessary.
builder->setInsertInto(newBlock);
builder->emitBranch(defaultBlock);
}
- else
+ else if (witnessTables.getCount() > 1)
{
auto breakBlock = builder->emitBlock();
builder->setInsertInto(breakBlock);
@@ -144,6 +144,21 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
caseBlocks.getCount(),
caseBlocks.getBuffer());
}
+ else
+ {
+ // We have no witness tables that implement this interface.
+ // Just return a default value.
+ builder->setInsertInto(newBlock);
+ if (callInst->getDataType()->getOp() == kIROp_VoidType)
+ {
+ builder->emitReturn();
+ }
+ else
+ {
+ auto defaultValue = builder->emitConstructorInst(callInst->getDataType(), 0, nullptr);
+ builder->emitReturn(defaultValue);
+ }
+ }
// Remove old implementation.
dispatchFunc->replaceUsesWith(newDispatchFunc);
dispatchFunc->removeAndDeallocate();
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index ab77ec88e..1e63be890 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -100,7 +100,7 @@ struct FunctionParameterSpecializationContext
// With the basic state out of the way, let's walk
// through the overall flow of the pass.
//
- void processModule()
+ bool processModule()
{
// We will start by initializing our IR building state.
//
@@ -112,6 +112,8 @@ struct FunctionParameterSpecializationContext
//
addCallsToWorkListRec(module->getModuleInst());
+ bool changed = false;
+
// We will process the work list until it goes dry,
// treating it like a stack of work items.
//
@@ -130,8 +132,10 @@ struct FunctionParameterSpecializationContext
if( canSpecializeCall(call) )
{
specializeCall(call);
+ changed = true;
}
}
+ return changed;
}
// Setting up the work list is a simple recursive procedure.
@@ -353,6 +357,7 @@ struct FunctionParameterSpecializationContext
// we need to generate a call to it, and then use the new
// call as a replacement for the old one.
//
+ SLANG_ASSERT(newFunc != oldCall->getCallee());
auto newCall = getBuilder()->emitCallInst(
oldCall->getFullType(),
newFunc,
@@ -877,7 +882,7 @@ struct FunctionParameterSpecializationContext
// is straighforward. We set up the context object
// and then defer to it for the real work.
//
-void specializeFunctionCalls(
+bool specializeFunctionCalls(
BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module,
@@ -889,7 +894,7 @@ void specializeFunctionCalls(
context.module = module;
context.condition = condition;
- context.processModule();
+ return context.processModule();
}
} // namesapce Slang
diff --git a/source/slang/slang-ir-specialize-function-call.h b/source/slang/slang-ir-specialize-function-call.h
index 868f9def2..4afb0526f 100644
--- a/source/slang/slang-ir-specialize-function-call.h
+++ b/source/slang/slang-ir-specialize-function-call.h
@@ -26,8 +26,8 @@ namespace Slang
/// a specialized variant of the function that does not have
/// those resource parameters (and instead, e.g, refers to the
/// global shader parameters directly).
- ///
- void specializeFunctionCalls(
+ /// Returns true if any changes are made.
+ bool specializeFunctionCalls(
BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module,
diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp
index 1b3fc65f4..01e1c396c 100644
--- a/source/slang/slang-ir-specialize-resources.cpp
+++ b/source/slang/slang-ir-specialize-resources.cpp
@@ -6,6 +6,9 @@
#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
+#include "slang-ir-ssa-simplification.h"
+
+#include "slang-ir-inline.h"
namespace Slang
{
@@ -55,14 +58,11 @@ struct ResourceParameterSpecializationCondition : FunctionCallSpecializeConditio
// For GL/Vulkan targets, we also need to specialize
// any parameters that use structured or byte-addressed
- // buffers.
+ // buffers or images with format qualifiers.
//
if( isKhronosTarget(targetRequest) )
{
- if(as<IRHLSLStructuredBufferTypeBase>(type))
- return true;
- if(as<IRByteAddressBufferTypeBase>(type))
- return true;
+ return isIllegalGLSLParameterType(type);
}
// For now, we will not treat any other parameters as
@@ -84,14 +84,22 @@ struct ResourceParameterSpecializationCondition : FunctionCallSpecializeConditio
}
};
-void specializeResourceParameters(
+bool specializeResourceParameters(
BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module)
{
+ bool result = false;
ResourceParameterSpecializationCondition condition;
condition.targetRequest = targetRequest;
- specializeFunctionCalls(compileRequest, targetRequest, module, &condition);
+ bool changed = true;
+ while (changed)
+ {
+ changed = specializeFunctionCalls(compileRequest, targetRequest, module, &condition);
+ simplifyIR(module);
+ result |= changed;
+ }
+ return result;
}
/// A pass to specialize resource-typed function outputs
@@ -111,8 +119,13 @@ struct ResourceOutputSpecializationPass
SharedIRBuilder sharedBuilder;
SharedIRBuilder* getSharedBuilder() { return &sharedBuilder; }
- void processModule()
+ // Functions that require specialization but are currently unspecializable.
+ List<IRFunc*>* unspecializableFuncs;
+
+ bool processModule()
{
+ bool changed = false;
+
// We start by setting up the shared IR building state.
//
sharedBuilder.init(module);
@@ -127,11 +140,12 @@ struct ResourceOutputSpecializationPass
if(!func)
continue;
- processFunc(func);
+ changed |= processFunc(func);
}
+ return changed;
}
- void processFunc(IRFunc* oldFunc)
+ bool processFunc(IRFunc* oldFunc)
{
// We don't want to waste any effort on functions that don't merit
// specialization, so the first step is to identify if the function
@@ -141,7 +155,7 @@ struct ResourceOutputSpecializationPass
// the given function.
//
if(!shouldSpecializeFunc(oldFunc))
- return;
+ return false;
// It is possible that we have a function that we *should* specialize
// (based on its signature), but we *cannot* yet specialize it.
@@ -201,7 +215,9 @@ struct ResourceOutputSpecializationPass
// are sure we can optimize/simplify, so that the error
// messages can be front-end rather than back-end errors.
//
- return;
+ newFunc->removeAndDeallocate();
+ unspecializableFuncs->add(oldFunc);
+ return false;
}
// Specialization might have changed the signature of `newFunc`,
@@ -265,6 +281,7 @@ struct ResourceOutputSpecializationPass
{
specializeCallSite(oldCall, newFunc, funcInfo);
}
+ return true;
}
// With the overall flow of the pass described, we can now drill down
@@ -1096,16 +1113,13 @@ struct ResourceOutputSpecializationPass
// but transforming them so that the function signatures are changed makes
// the challenge more explicit and thus perhaps easier to tackle.
- // TODO: We probably need to update the two passes in this file so that they
- // work in an iterative fashion (combined with some SSA "cleanup" on function
- // bodies), because each optimization may open up opportunties for the other
- // to apply.
};
-void specializeResourceOutputs(
+bool specializeResourceOutputs(
BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
- IRModule* module)
+ IRModule* module,
+ List<IRFunc*>& unspecializableFuncs)
{
if(isD3DTarget(targetRequest) || isKhronosTarget(targetRequest))
{}
@@ -1118,14 +1132,102 @@ void specializeResourceOutputs(
// of conditional in a way that doesn't involve explicitly
// enumerating matching targets.
//
- return;
+ return false;
}
ResourceOutputSpecializationPass pass;
pass.compileRequest = compileRequest;
pass.targetRequest = targetRequest;
pass.module = module;
- pass.processModule();
+ pass.unspecializableFuncs = &unspecializableFuncs;
+ return pass.processModule();
+}
+
+// Iteratively applies `specializeResourceOutputs` and
+// `specializeResourceParameters`, interleaved with `simplifyIR`, and inlines
+// calls to functions that `specializeResourceOutputs` reported as
+// unspecializable before trying again.
+//
+// Returns true if any of the specialization passes reported a change.
+bool specializeResourceUsage(
+    BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* irModule)
+{
+    bool result = false;
+    // We apply two kinds of specialization to clean up resource value usage:
+    //
+    // * Specialize call sites based on the actual resources
+    //   that a called function will return/output.
+    //
+    // * Specialize called functions based on the actual resources
+    //   passed as input at specific call sites.
+    //
+    // We need to run the two passes in an iterative fashion (combined with IR
+    // simplification passes), because each optimization may open up opportunities
+    // for the other to apply.
+    //
+    for (;;)
+    {
+        bool changed = true;
+        List<IRFunc*> unspecializableFuncs;
+        while (changed)
+        {
+            changed = false;
+            unspecializableFuncs.clear();
+            // Because the legalization may depend on what target
+            // we are compiling for (certain things might be okay
+            // for D3D targets that are not okay for Vulkan), we
+            // pass down the target request along with the IR.
+            //
+            changed |= specializeResourceOutputs(
+                compileRequest, targetRequest, irModule, unspecializableFuncs);
+            changed |= specializeResourceParameters(compileRequest, targetRequest, irModule);
+
+            // After specialization of function outputs, we may find that there
+            // are cases where opaque-typed local variables can now be eliminated
+            // and turned into SSA temporaries. Such optimization may enable
+            // the following passes to "see" and specialize more cases.
+            //
+            simplifyIR(irModule);
+            result |= changed;
+        }
+        if (unspecializableFuncs.getCount() == 0)
+            break;
+
+        // Inline unspecializable resource output functions and then continue trying.
+        bool foundCallSite = false;
+        for (auto func : unspecializableFuncs)
+        {
+            // Gather the call sites first: inlining a call removes it, which
+            // would otherwise modify the use list of `func` while it is being
+            // walked through `nextUse`.
+            List<IRCall*> callSites;
+            for (auto use = func->firstUse; use; use = use->nextUse)
+            {
+                auto call = as<IRCall>(use->getUser());
+                if (!call)
+                    continue;
+                // Ignore calls that merely pass `func` as an argument.
+                if (call->getCallee() != func)
+                    continue;
+                callSites.add(call);
+            }
+            for (auto call : callSites)
+            {
+                inlineCall(call);
+                foundCallSite = true;
+            }
+        }
+        // If there was nothing left to inline, another round cannot make
+        // progress; stop instead of looping forever.
+        if (!foundCallSite)
+            break;
+        simplifyIR(irModule);
+    }
+    return result;
+}
+
+// Returns true when `type` is a uniform parameter group, an HLSL structured
+// buffer, a byte-address buffer, a GLSL image, or a texture whose access is
+// read-write or raster-ordered; returns false for every other type.
+bool isIllegalGLSLParameterType(IRType* type)
+{
+    // Types that are rejected regardless of any further property.
+    if (as<IRUniformParameterGroupType>(type) ||
+        as<IRHLSLStructuredBufferTypeBase>(type) ||
+        as<IRByteAddressBufferTypeBase>(type) ||
+        as<IRGLSLImageType>(type))
+    {
+        return true;
+    }
+
+    // Textures are rejected only for the writable access modes.
+    auto textureType = as<IRTextureType>(type);
+    if (!textureType)
+        return false;
+
+    const auto access = textureType->getAccess();
+    return access == SLANG_RESOURCE_ACCESS_READ_WRITE
+        || access == SLANG_RESOURCE_ACCESS_RASTER_ORDERED;
+}
} // namespace Slang
diff --git a/source/slang/slang-ir-specialize-resources.h b/source/slang/slang-ir-specialize-resources.h
index 62a2728bc..0f2fdc99e 100644
--- a/source/slang/slang-ir-specialize-resources.h
+++ b/source/slang/slang-ir-specialize-resources.h
@@ -6,6 +6,7 @@ namespace Slang
class BackEndCompileRequest;
class TargetRequest;
struct IRModule;
+ struct IRType;
/// Specialize calls to functions with resource-type parameters.
///
@@ -17,13 +18,20 @@ namespace Slang
/// those resource parameters (and instead, e.g, refers to the
/// global shader parameters directly).
///
- void specializeResourceParameters(
+ bool specializeResourceParameters(
BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module);
- void specializeResourceOutputs(
+ bool specializeResourceOutputs(
BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module);
+
+ /// Combined iterative passes of `specializeResourceParameters` and `specializeResourceOutputs`.
+ bool specializeResourceUsage(
+ BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* irModule);
+
+ bool isIllegalGLSLParameterType(IRType* type);
+
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
new file mode 100644
index 000000000..fcc6dc4ae
--- /dev/null
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -0,0 +1,36 @@
+// slang-ir-ssa-simplification.cpp
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir.h"
+#include "slang-ir-ssa.h"
+#include "slang-ir-sccp.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-simplify-cfg.h"
+
+namespace Slang
+{
+ struct IRModule;
+
+    // Repeatedly runs SCCP, SimplifyCFG, DeadCodeElimination and SSA construction
+    // on `module`. Iteration stops as soon as a round reports no change, or
+    // after `kMaxIterations` rounds, whichever comes first.
+    void simplifyIR(IRModule* module)
+    {
+        const int kMaxIterations = 8;
+        for (int iteration = 0; iteration < kMaxIterations; iteration++)
+        {
+            bool roundChanged = applySparseConditionalConstantPropagation(module);
+            if (simplifyCFG(module))
+                roundChanged = true;
+
+            // The result of dead code elimination is deliberately ignored:
+            // SCCP may create temporary constants that are never used, so DCE
+            // would remove those newly generated constants and report a change
+            // on every round.
+            eliminateDeadCode(module);
+
+            if (constructSSA(module))
+                roundChanged = true;
+
+            if (!roundChanged)
+                break;
+        }
+    }
+}
diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h
new file mode 100644
index 000000000..19a39e8d4
--- /dev/null
+++ b/source/slang/slang-ir-ssa-simplification.h
@@ -0,0 +1,11 @@
+// slang-ir-ssa-simplification.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination passes
+ // until no more changes are possible or an internal iteration limit is reached.
+ void simplifyIR(IRModule* module);
+}
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 1804664fb..797fcb25c 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -1039,7 +1039,7 @@ static void breakCriticalEdges(
}
// Construct SSA form for a global value with code
-void constructSSA(ConstructSSAContext* context)
+bool constructSSA(ConstructSSAContext* context)
{
// First, detect and and break any critical edges in the CFG,
// because our representation of SSA form doesn't allow for them.
@@ -1052,7 +1052,7 @@ void constructSSA(ConstructSSAContext* context)
// If none of the variables are promote-able,
// then we can exit without making any changes
if (context->promotableVars.getCount() == 0)
- return;
+ return false;
// We are going to walk the blocks in order,
// and try to process each, by replacing loads
@@ -1187,10 +1187,12 @@ void constructSSA(ConstructSSAContext* context)
{
var->removeAndDeallocate();
}
+
+ return true;
}
// Construct SSA form for a global value with code
-void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal)
+bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal)
{
ConstructSSAContext context;
context.globalVal = globalVal;
@@ -1200,28 +1202,31 @@ void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal)
context.builder.init(context.sharedBuilder);
context.builder.setInsertInto(module);
- constructSSA(&context);
+ return constructSSA(&context);
}
-void constructSSA(IRModule* module, IRInst* globalVal)
+bool constructSSA(IRModule* module, IRInst* globalVal)
{
switch (globalVal->getOp())
{
case kIROp_Func:
case kIROp_GlobalVar:
- constructSSA(module, (IRGlobalValueWithCode*)globalVal);
+ return constructSSA(module, (IRGlobalValueWithCode*)globalVal);
default:
break;
}
+ return false;
}
-void constructSSA(IRModule* module)
+bool constructSSA(IRModule* module)
{
+ bool changed = false;
for(auto ii : module->getGlobalInsts())
{
- constructSSA(module, ii);
+ changed |= constructSSA(module, ii);
}
+ return changed;
}
}
diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h
index 635810c08..b327802a1 100644
--- a/source/slang/slang-ir-ssa.h
+++ b/source/slang/slang-ir-ssa.h
@@ -4,6 +4,7 @@
namespace Slang
{
struct IRModule;
-
- void constructSSA(IRModule* module);
+ struct IRGlobalValueWithCode;
+ bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal);
+ bool constructSSA(IRModule* module);
}
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index ffadf33cf..e8da87187 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -196,6 +196,50 @@ namespace Slang
}
}
+    // Validates the control flow of `code`: every block must have a
+    // terminator, and every block targeted by a conditional branch, loop,
+    // unconditional branch or switch terminator must be one of the blocks of
+    // `code` itself. Failures are reported through `validate`.
+    void validateCodeBody(IRValidateContext* context, IRGlobalValueWithCode* code)
+    {
+        // The set of blocks that are legal branch targets within `code`.
+        HashSet<IRBlock*> blocks;
+        for (auto block : code->getBlocks())
+            blocks.Add(block);
+        auto validateBranchTarget = [&](IRInst* inst, IRBlock* target)
+        {
+            validate(
+                context,
+                blocks.Contains(target),
+                inst,
+                "branch inst must have a valid target block that is defined within the same "
+                "scope.");
+        };
+        for (auto block : code->getBlocks())
+        {
+            auto terminator = block->getTerminator();
+            validate(context, terminator, block, "block must have valid terminator inst.");
+            // The missing terminator has been reported above; there is nothing
+            // further to inspect for this block, and dereferencing it would crash.
+            if (!terminator)
+                continue;
+            switch (terminator->getOp())
+            {
+            case kIROp_conditionalBranch:
+                validateBranchTarget(
+                    terminator, as<IRConditionalBranch>(terminator)->getTrueBlock());
+                validateBranchTarget(
+                    terminator, as<IRConditionalBranch>(terminator)->getFalseBlock());
+                break;
+            case kIROp_loop:
+            case kIROp_unconditionalBranch:
+                validateBranchTarget(terminator, as<IRUnconditionalBranch>(terminator)->getTargetBlock());
+                break;
+            case kIROp_Switch:
+                {
+                    auto switchInst = as<IRSwitch>(terminator);
+                    for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+                    {
+                        validateBranchTarget(switchInst, switchInst->getCaseLabel(i));
+                    }
+                    validateBranchTarget(switchInst, switchInst->getDefaultLabel());
+                    validateBranchTarget(switchInst, switchInst->getBreakLabel());
+                }
+                break;
+            default:
+                // Other terminators carry no block targets that are checked here.
+                break;
+            }
+        }
+    }
+
void validateIRInst(
IRValidateContext* context,
IRInst* inst)
@@ -207,6 +251,11 @@ namespace Slang
// If `inst` is itself a parent instruction, then we need to recursively
// validate its children.
validateIRInstChildren(context, inst);
+
+ if (auto code = as<IRGlobalValueWithCode>(inst))
+ {
+ validateCodeBody(context, code);
+ }
}
void validateIRModule(IRModule* module, DiagnosticSink* sink)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f7ebfdb64..721488f82 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3541,6 +3541,21 @@ namespace Slang
return inst;
}
+    // Creates a `kIROp_ImageLoad` inst of type `type` whose operands are, in
+    // order, `image` and `coord`, adds it via `addInst`, and returns it.
+    IRInst* IRBuilder::emitImageLoad(IRType* type, IRInst* image, IRInst* coord)
+    {
+        IRInst* operands[] = {image, coord};
+        auto imageLoad = createInst<IRImageLoad>(this, kIROp_ImageLoad, type, 2, operands);
+        addInst(imageLoad);
+        return imageLoad;
+    }
+
+ // Creates a `kIROp_ImageStore` inst of type `type` whose operands are, in
+ // order, `image`, `coord` and `value`, adds it via `addInst`, and returns it.
+ IRInst* IRBuilder::emitImageStore(IRType* type, IRInst* image, IRInst* coord, IRInst* value)
+ {
+ // Operand order: image, coordinate, value to store.
+ IRInst* args[] = {image, coord, value};
+ auto inst = createInst<IRImageStore>(this, kIROp_ImageStore, type, 3, args);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitFieldExtract(
IRType* type,
IRInst* base,
@@ -5730,6 +5745,7 @@ namespace Slang
case kIROp_makeArray:
case kIROp_makeStruct:
case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads
+ case kIROp_ImageSubscript:
case kIROp_FieldExtract:
case kIROp_FieldAddress:
case kIROp_getElement:
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index fe32c93b1..e59573cef 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -773,7 +773,7 @@ struct GLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getVaryingInputRules() override;
virtual LayoutRulesImpl* getVaryingOutputRules() override;
virtual LayoutRulesImpl* getSpecializationConstantRules() override;
- virtual LayoutRulesImpl* getShaderStorageBufferRules() override;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override;
virtual LayoutRulesImpl* getParameterBlockRules() override;
LayoutRulesImpl* getRayPayloadParameterRules() override;
@@ -794,7 +794,7 @@ struct HLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getVaryingInputRules() override;
virtual LayoutRulesImpl* getVaryingOutputRules() override;
virtual LayoutRulesImpl* getSpecializationConstantRules() override;
- virtual LayoutRulesImpl* getShaderStorageBufferRules() override;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override;
virtual LayoutRulesImpl* getParameterBlockRules() override;
LayoutRulesImpl* getRayPayloadParameterRules() override;
@@ -815,7 +815,7 @@ struct CPULayoutRulesFamilyImpl : LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getVaryingInputRules() override;
virtual LayoutRulesImpl* getVaryingOutputRules() override;
virtual LayoutRulesImpl* getSpecializationConstantRules() override;
- virtual LayoutRulesImpl* getShaderStorageBufferRules() override;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override;
virtual LayoutRulesImpl* getParameterBlockRules() override;
LayoutRulesImpl* getRayPayloadParameterRules() override;
@@ -835,7 +835,7 @@ struct CUDALayoutRulesFamilyImpl : LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getVaryingInputRules() override;
virtual LayoutRulesImpl* getVaryingOutputRules() override;
virtual LayoutRulesImpl* getSpecializationConstantRules() override;
- virtual LayoutRulesImpl* getShaderStorageBufferRules() override;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) override;
virtual LayoutRulesImpl* getParameterBlockRules() override;
LayoutRulesImpl* getRayPayloadParameterRules() override;
@@ -1140,8 +1140,10 @@ LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getSpecializationConstantRules()
return &kGLSLSpecializationConstantLayoutRulesImpl_;
}
-LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules()
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest* request)
{
+ if (request->getForceGLSLScalarBufferLayout())
+ return &kHLSLStructuredBufferLayoutRulesImpl_;
return &kStd430LayoutRulesImpl_;
}
@@ -1219,7 +1221,7 @@ LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getSpecializationConstantRules()
return nullptr;
}
-LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules()
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest*)
{
return nullptr;
}
@@ -1273,7 +1275,7 @@ LayoutRulesImpl* CPULayoutRulesFamilyImpl::getSpecializationConstantRules()
{
return nullptr;
}
-LayoutRulesImpl* CPULayoutRulesFamilyImpl::getShaderStorageBufferRules()
+LayoutRulesImpl* CPULayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest*)
{
return nullptr;
}
@@ -1339,7 +1341,7 @@ LayoutRulesImpl* CUDALayoutRulesFamilyImpl::getSpecializationConstantRules()
{
return nullptr;
}
-LayoutRulesImpl* CUDALayoutRulesFamilyImpl::getShaderStorageBufferRules()
+LayoutRulesImpl* CUDALayoutRulesFamilyImpl::getShaderStorageBufferRules(TargetRequest*)
{
return nullptr;
}
@@ -2538,7 +2540,8 @@ static RefPtr<TypeLayout> _createParameterGroupTypeLayout(
LayoutRulesImpl* getParameterBufferElementTypeLayoutRules(
ParameterGroupType* parameterGroupType,
- LayoutRulesImpl* rules)
+ LayoutRulesImpl* rules,
+ TargetRequest* targetRequest)
{
if( as<ConstantBufferType>(parameterGroupType) )
{
@@ -2558,7 +2561,7 @@ LayoutRulesImpl* getParameterBufferElementTypeLayoutRules(
}
else if( as<GLSLShaderStorageBufferType>(parameterGroupType) )
{
- return rules->getLayoutRulesFamily()->getShaderStorageBufferRules();
+ return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(targetRequest);
}
else if (as<ParameterBlockType>(parameterGroupType))
{
@@ -2580,7 +2583,8 @@ RefPtr<TypeLayout> createParameterGroupTypeLayout(
// Determine the layout rules to use for the contents of the block
auto elementTypeRules = getParameterBufferElementTypeLayoutRules(
parameterGroupType,
- parameterGroupRules);
+ parameterGroupRules,
+ context.targetReq);
auto elementType = parameterGroupType->elementType;
diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h
index 6e28b6c9d..219dfad0f 100644
--- a/source/slang/slang-type-layout.h
+++ b/source/slang/slang-type-layout.h
@@ -1001,7 +1001,7 @@ struct LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getVaryingInputRules() = 0;
virtual LayoutRulesImpl* getVaryingOutputRules() = 0;
virtual LayoutRulesImpl* getSpecializationConstantRules()= 0;
- virtual LayoutRulesImpl* getShaderStorageBufferRules() = 0;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules(TargetRequest* request) = 0;
virtual LayoutRulesImpl* getParameterBlockRules() = 0;
virtual LayoutRulesImpl* getRayPayloadParameterRules() = 0;
diff --git a/tests/bugs/sccp-switch-case-removal.slang b/tests/bugs/sccp-switch-case-removal.slang
new file mode 100644
index 000000000..cf8d36ef7
--- /dev/null
+++ b/tests/bugs/sccp-switch-case-removal.slang
@@ -0,0 +1,25 @@
+//TEST(compute,vulkan):COMPARE_COMPUTE_EX:-vk -slang -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+static const uint kConstant = 5;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float x = 1.0;
+ switch(kConstant)
+ {
+ case 5:
+ int tid = int(dispatchThreadID.x);
+ outputBuffer[tid] = 0;
+ break;
+ case 1:
+ ddx(x); // this should trigger glslang error if it doesn't get eliminated by sccp.
+ break;
+ default:
+ ddy(x); // this should trigger glslang error if it doesn't get eliminated by sccp.
+ break;
+ }
+} \ No newline at end of file
diff --git a/tests/bugs/sccp-switch-case-removal.slang.expected.txt b/tests/bugs/sccp-switch-case-removal.slang.expected.txt
new file mode 100644
index 000000000..ae25f7400
--- /dev/null
+++ b/tests/bugs/sccp-switch-case-removal.slang.expected.txt
@@ -0,0 +1,4 @@
+0
+0
+0
+0 \ No newline at end of file
diff --git a/tests/bugs/vk-image-write.slang b/tests/bugs/vk-image-write.slang
new file mode 100644
index 000000000..10141384c
--- /dev/null
+++ b/tests/bugs/vk-image-write.slang
@@ -0,0 +1,16 @@
+//TEST:CROSS_COMPILE: -profile ps_5_0 -entry main -target spirv-assembly
+
+// Ensure that we can lower to `imageStore` correctly.
+
+RWTexture2D<float4> t;
+
+void writeColor(float3 v)
+{
+ t[uint2(0,0)].xyz += v;
+}
+
+float4 main() : SV_Target
+{
+ writeColor(float3(1.0));
+ return float4(0);
+}
diff --git a/tests/bugs/vk-image-write.slang.glsl b/tests/bugs/vk-image-write.slang.glsl
new file mode 100644
index 000000000..7b9bcebf9
--- /dev/null
+++ b/tests/bugs/vk-image-write.slang.glsl
@@ -0,0 +1,40 @@
+//TEST_IGNORE_FILE:
+
+#version 450
+layout(row_major) uniform;
+layout(row_major) buffer;
+
+layout(rgba32f)
+layout(binding = 0)
+uniform image2D t_0;
+
+void writeColor_0(vec3 v_0)
+{
+ const uvec2 _S1 = uvec2(0U, 0U);
+
+ vec4 _S2 = (imageLoad((t_0), ivec2((_S1))));
+
+ vec3 _S3 = _S2.xyz + v_0;
+
+ ivec2 _S4 = ivec2(_S1);
+
+ vec4 _S5 = imageLoad(t_0,_S4);
+
+ vec4 _S6 = _S5;
+ _S6.xyz = _S3;
+
+ imageStore(t_0,_S4,_S6);
+ return;
+}
+
+
+layout(location = 0)
+out vec4 _S7;
+
+void main()
+{
+ writeColor_0(vec3(1.00000000000000000000));
+ _S7 = vec4(0);
+ return;
+}
+
diff --git a/tests/bugs/vk-structured-buffer-load.hlsl.glsl b/tests/bugs/vk-structured-buffer-load.hlsl.glsl
index 1c7ec8043..05c8de193 100644
--- a/tests/bugs/vk-structured-buffer-load.hlsl.glsl
+++ b/tests/bugs/vk-structured-buffer-load.hlsl.glsl
@@ -43,13 +43,12 @@ void main()
float HitT_0 = (gl_RayTmaxNV);
RayData.PackedHitInfoA_0.x = HitT_0;
- const uint use_rcp_0 = uint(0);
- float offsfloat_0 = ((gParamBlock_sbuf_0)._data[(int(uint(0)))]);
+ float offsfloat_0 = ((gParamBlock_sbuf_0)._data[(0)]);
- uint use_rcp_1 = use_rcp_0|uint(HitT_0 > 0.00000000000000000000);
+ uint use_rcp_0 = 0U | uint(HitT_0 > 0.00000000000000000000);
- if(bool(use_rcp_1))
+ if(bool(use_rcp_0))
{
float tmpA = rcp_0(offsfloat_0);
@@ -60,7 +59,7 @@ void main()
else
{
- if(use_rcp_1 > uint(0)&&offsfloat_0 == 0.00000000000000000000)
+ if(use_rcp_0 > 0U&&offsfloat_0 == 0.00000000000000000000)
{
float tmpB = (inversesqrt((offsfloat_0 + 1.00000000000000000000)));
diff --git a/tests/optimization/func-resource-result/func-resource-result-complex.slang b/tests/optimization/func-resource-result/func-resource-result-complex.slang
new file mode 100644
index 000000000..a5585ff4c
--- /dev/null
+++ b/tests/optimization/func-resource-result/func-resource-result-complex.slang
@@ -0,0 +1,45 @@
+// func-resource-result-complex.slang
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+
+// Test that a function that returns a resource type can be
+// compiled for targets that don't natively support resource
+// return values.
+
+//TEST_INPUT:set textures=[Texture2D(size=4, content = zero), Texture2D(size=4, content = one)]
+Texture2D textures[2];
+
+//TEST_INPUT:set sampler=Sampler
+SamplerState sampler;
+
+Texture2D getTex(int index)
+{
+ Texture2D result;
+ // Note: `index` here needs to be a compile-time constant in order to generate
+ // valid GLSL. If constant folding and function inlining are both done correctly,
+ // we should be able to compile this.
+ if (index == 0)
+ result = textures[1];
+ else
+ result = textures[0];
+ return result;
+}
+
+int test(int val)
+{
+ // Make sure index is a compile-time constant.
+ return getTex(int(0.0) + 1*2 < 5 ? 0 : 1).SampleLevel(sampler, float2(0,0), 0).x == 0.0 ? 0 : 1;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt b/tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt
new file mode 100644
index 000000000..ef529012e
--- /dev/null
+++ b/tests/optimization/func-resource-result/func-resource-result-complex.slang.expected.txt
@@ -0,0 +1,4 @@
+1
+1
+1
+1 \ No newline at end of file
diff --git a/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl b/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl
index 1454d493f..59fecd544 100644
--- a/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl
+++ b/tests/pipeline/rasterization/conservative-rasterization/inner-coverage.slang.glsl
@@ -1,14 +1,14 @@
+//TEST_IGNORE_FILE:
#version 450
-
#extension GL_NV_conservative_raster_underestimation : require
+layout(row_major) uniform;
+layout(row_major) buffer;
layout(location = 0)
out vec4 _S1;
void main()
{
- vec4 _S2;
- _S2 = vec4(uint(gl_FragFullyCoveredNV));
- _S1 = _S2;
+ _S1 = vec4(uint(gl_FragFullyCoveredNV));
return;
}
diff --git a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl
index a12b9827b..1818b7789 100644
--- a/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl
+++ b/tests/pipeline/rasterization/fragment-shader-interlock.slang.glsl
@@ -2,6 +2,8 @@
#version 450
#extension GL_ARB_fragment_shader_interlock : require
+layout(row_major) uniform;
+layout(row_major) buffer;
layout(rgba32f)
layout(binding = 0)
@@ -15,13 +17,9 @@ out vec4 _S2;
void main()
{
- vec4 _S3;
-
beginInvocationInterlockARB();
- vec4 _S4 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S1.xy)))));
-
- _S3 = _S4;
+ vec4 _S3 = (imageLoad((entryPointParams_texture_0), ivec2((uvec2(_S1.xy)))));
imageStore((entryPointParams_texture_0), ivec2((uvec2(_S1.xy))), _S3 + _S1);
endInvocationInterlockARB();
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl
index c07e9b61c..1da5f4f8a 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.glsl
@@ -2,24 +2,18 @@
//TEST_IGNORE_FILE:
#version 450
-
#extension GL_NV_fragment_shader_barycentric : require
+layout(row_major) uniform;
+layout(row_major) buffer;
pervertexNV layout(location = 0)
-in vec4 _S1[3];
+in vec4 _S1[3];
layout(location = 0)
out vec4 _S2;
void main()
{
- vec4 _S3;
-
- _S3 = gl_BaryCoordNV.x * _S1[0]
- + gl_BaryCoordNV.y * _S1[1]
- + gl_BaryCoordNV.z * _S1[2];
-
- _S2 = _S3;
-
+ _S2 = gl_BaryCoordNV.x * ((_S1)[(0U)]) + gl_BaryCoordNV.y * ((_S1)[(1U)]) + gl_BaryCoordNV.z * ((_S1)[(2U)]);
return;
}
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index abe5a3b4d..7cfdf06c8 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -399,14 +399,20 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL setCurrentValue(uint64_t value) override
{
- VkSemaphoreSignalInfo signalInfo;
- signalInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO;
- signalInfo.pNext = NULL;
- signalInfo.semaphore = m_semaphore;
- signalInfo.value = 2;
+ uint64_t currentValue = 0;
+ SLANG_VK_RETURN_ON_FAIL(m_device->m_api.vkGetSemaphoreCounterValue(
+ m_device->m_api.m_device, m_semaphore, &currentValue));
+ if (currentValue < value)
+ {
+ VkSemaphoreSignalInfo signalInfo;
+ signalInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO;
+ signalInfo.pNext = NULL;
+ signalInfo.semaphore = m_semaphore;
+ signalInfo.value = value;
- SLANG_VK_RETURN_ON_FAIL(
- m_device->m_api.vkSignalSemaphore(m_device->m_api.m_device, &signalInfo));
+ SLANG_VK_RETURN_ON_FAIL(
+ m_device->m_api.vkSignalSemaphore(m_device->m_api.m_device, &signalInfo));
+ }
return SLANG_OK;
}
@@ -881,7 +887,7 @@ public:
ComPtr<IResourceView> depthStencilView;
uint32_t m_width;
uint32_t m_height;
- RefPtr<VKDevice> m_renderer;
+ BreakableReference<VKDevice> m_renderer;
VkClearValue m_clearValues[kMaxAttachments];
RefPtr<FramebufferLayoutImpl> m_layout;
public:
@@ -911,7 +917,7 @@ public:
m_height = getMipLevelSize(viewDesc->subresourceRange.mipLevel, size.height);
layerCount = viewDesc->subresourceRange.layerCount;
}
- else
+ else if (desc.renderTargetCount)
{
// If we don't have a depth attachment, then we must have at least
// one color attachment. Get frame dimension from there.
@@ -923,6 +929,12 @@ public:
m_height = getMipLevelSize(viewDesc->subresourceRange.mipLevel, size.height);
layerCount = viewDesc->subresourceRange.layerCount;
}
+ else
+ {
+ m_width = 1;
+ m_height = 1;
+ layerCount = 1;
+ }
if (layerCount == 0)
layerCount = 1;
// Create render pass.
@@ -2978,19 +2990,20 @@ public:
Index count = resourceViews.getCount();
for(Index i = 0; i < count; ++i)
{
- auto bufferView = static_cast<PlainBufferResourceViewImpl*>(resourceViews[i].Ptr());
-
VkDescriptorBufferInfo bufferInfo = {};
+ bufferInfo.range = VK_WHOLE_SIZE;
- if(bufferView)
- {
- bufferInfo.buffer = bufferView->m_buffer->m_buffer.m_buffer;
- bufferInfo.offset = bufferView->offset;
- bufferInfo.range = bufferView->size;
- }
- else
+ if (resourceViews[i])
{
- bufferInfo.range = VK_WHOLE_SIZE;
+ auto boundViewType = static_cast<ResourceViewImpl*>(resourceViews[i].Ptr())->m_type;
+ if (boundViewType == ResourceViewImpl::ViewType::PlainBuffer)
+ {
+ auto bufferView =
+ static_cast<PlainBufferResourceViewImpl*>(resourceViews[i].Ptr());
+ bufferInfo.buffer = bufferView->m_buffer->m_buffer.m_buffer;
+ bufferInfo.offset = bufferView->offset;
+ bufferInfo.range = bufferView->size;
+ }
}
VkWriteDescriptorSet write = {};
@@ -3017,11 +3030,17 @@ public:
Index count = resourceViews.getCount();
for(Index i = 0; i < count; ++i)
{
- auto resourceView = static_cast<TexelBufferResourceViewImpl*>(resourceViews[i].Ptr());
VkBufferView bufferView = VK_NULL_HANDLE;
- if (resourceView)
+ if (resourceViews[i])
{
- bufferView = resourceView->m_view;
+ auto boundViewType =
+ static_cast<ResourceViewImpl*>(resourceViews[i].Ptr())->m_type;
+ if (boundViewType == ResourceViewImpl::ViewType::TexelBuffer)
+ {
+ auto resourceView =
+ static_cast<TexelBufferResourceViewImpl*>(resourceViews[i].Ptr());
+ bufferView = resourceView->m_view;
+ }
}
VkWriteDescriptorSet write = {};
write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
@@ -3125,12 +3144,17 @@ public:
Index count = resourceViews.getCount();
for(Index i = 0; i < count; ++i)
{
- auto texture = static_cast<TextureResourceViewImpl*>(resourceViews[i].Ptr());
VkDescriptorImageInfo imageInfo = {};
- if (texture)
+ if (resourceViews[i])
{
- imageInfo.imageView = texture->m_view;
- imageInfo.imageLayout = texture->m_layout;
+ auto boundViewType =
+ static_cast<ResourceViewImpl*>(resourceViews[i].Ptr())->m_type;
+ if (boundViewType == ResourceViewImpl::ViewType::Texture)
+ {
+ auto texture = static_cast<TextureResourceViewImpl*>(resourceViews[i].Ptr());
+ imageInfo.imageView = texture->m_view;
+ imageInfo.imageLayout = texture->m_layout;
+ }
}
imageInfo.sampler = 0;
@@ -5081,6 +5105,8 @@ public:
void beginPass(IRenderPassLayout* renderPass, IFramebuffer* framebuffer)
{
FramebufferImpl* framebufferImpl = static_cast<FramebufferImpl*>(framebuffer);
+ if (!framebuffer)
+ framebufferImpl = this->m_device->m_emptyFramebuffer;
RenderPassLayoutImpl* renderPassImpl =
static_cast<RenderPassLayoutImpl*>(renderPass);
VkClearValue clearValues[kMaxAttachments] = {};
@@ -6526,6 +6552,8 @@ public:
ChunkedList<RefPtr<RefObject>, 1024> m_deviceObjectsWithPotentialBackReferences;
VkSampler m_defaultSampler;
+
+ RefPtr<FramebufferImpl> m_emptyFramebuffer;
};
void VKDevice::PipelineCommandEncoder::init(CommandBufferImpl* commandBuffer)
@@ -6704,7 +6732,9 @@ VKDevice::~VKDevice()
m_deviceQueue.destroy();
descriptorSetAllocator.close();
-
+
+ m_emptyFramebuffer = nullptr;
+
if (m_device != VK_NULL_HANDLE)
{
if (m_desc.existingDeviceHandles.handles[2].handleValue == 0)
@@ -7187,6 +7217,7 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
deviceExtensions.add(VK_KHR_RAY_QUERY_EXTENSION_NAME);
m_features.add("ray-query");
m_features.add("ray-tracing");
+ m_features.add("sm_6_6");
}
if (extendedFeatures.bufferDeviceAddressFeatures.bufferDeviceAddress)
@@ -7243,6 +7274,13 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
#endif
m_features.add("external-memory");
}
+ if (extensionNames.Contains(VK_EXT_CONSERVATIVE_RASTERIZATION_EXTENSION_NAME))
+ {
+ deviceExtensions.add(VK_EXT_CONSERVATIVE_RASTERIZATION_EXTENSION_NAME);
+ m_features.add("conservative-rasterization-3");
+ m_features.add("conservative-rasterization-2");
+ m_features.add("conservative-rasterization-1");
+ }
if (extensionNames.Contains(VK_EXT_DEBUG_REPORT_EXTENSION_NAME))
{
deviceExtensions.add(VK_EXT_DEBUG_REPORT_EXTENSION_NAME);
@@ -7358,6 +7396,21 @@ SlangResult VKDevice::initialize(const Desc& desc)
SLANG_VK_RETURN_ON_FAIL(m_api.vkCreateSampler(m_device, &samplerInfo, nullptr, &m_defaultSampler));
}
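+ // Create an empty framebuffer (no render targets, no depth-stencil) that `beginPass` falls back to when given a null framebuffer.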
+ {
+ IFramebufferLayout::Desc layoutDesc = {};
+ layoutDesc.renderTargetCount = 0;
+ layoutDesc.depthStencil = nullptr;
+ ComPtr<IFramebufferLayout> layout;
+ SLANG_RETURN_ON_FAIL(createFramebufferLayout(layoutDesc, layout.writeRef()));
+ IFramebuffer::Desc desc = {};
+ desc.layout = layout;
+ ComPtr<IFramebuffer> framebuffer;
+ SLANG_RETURN_ON_FAIL(createFramebuffer(desc, framebuffer.writeRef()));
+ m_emptyFramebuffer = static_cast<FramebufferImpl*>(framebuffer.get());
+ m_emptyFramebuffer->m_renderer.breakStrongReference();
+ }
+
return SLANG_OK;
}
@@ -9194,6 +9247,16 @@ Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in
rasterizer.depthBiasSlopeFactor = rasterizerDesc.slopeScaledDepthBias;
rasterizer.lineWidth = 1.0f; // TODO: Currently unsupported
+ VkPipelineRasterizationConservativeStateCreateInfoEXT conservativeRasterInfo = {};
+ conservativeRasterInfo.sType =
+ VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_CONSERVATIVE_STATE_CREATE_INFO_EXT;
+ conservativeRasterInfo.conservativeRasterizationMode =
+ VK_CONSERVATIVE_RASTERIZATION_MODE_OVERESTIMATE_EXT;
+ if (desc.rasterizer.enableConservativeRasterization)
+ {
+ rasterizer.pNext = &conservativeRasterInfo;
+ }
+
auto framebufferLayoutImpl = static_cast<FramebufferLayoutImpl*>(desc.framebufferLayout);
auto forcedSampleCount = rasterizerDesc.forcedSampleCount;
auto blendDesc = desc.blend;