summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-04-26 01:27:30 -0400
committerGitHub <noreply@github.com>2024-04-26 01:27:30 -0400
commite91bd3b0bdc50f66bfd302ff079c65fba5126474 (patch)
tree73c1ce23a93b895216031e0071e954fa4ff641ea
parentbc7231bbcf32819bed37012db3f01b34b5dd856a (diff)
WIP: Force Inline If RefType (#4005)
* Force Inline if reftype Fixes #3997. If we are using a refType, we now ForceInline. remarks: 1. Modifications were made in slang-ir-glsl-legalize to change how we translate GlobalParam proxy's into GlobalParam. a. We now handle the senario where a globalParam is used in multiple disjoint blocks (like 2 different functions). * try to figure out why CI fails but local works try to inline DispatchMesh, works locally, may fail on CI(?) * try another fix * add task tests + don't allow semi-early task-shader inline Task shader uses DispatchMesh which is a very big 'hack' where we check for the function name and modify the callees in very large ways. This function does inline, but it cannot inline early due to future mangling that this operation requires todo. This is reflected with the `[noRefInline]` modifier. It is a modifier so users may stop mandatory inlines with `__ref` parameter.
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/hlsl.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h7
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp64
-rw-r--r--source/slang/slang-ir-inline.cpp20
-rw-r--r--source/slang/slang-ir-inline.h4
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
-rw-r--r--tests/bugs/gh-3997.slang23
-rw-r--r--tests/language-feature/non-copyable-return.slang4
-rw-r--r--tests/pipeline/rasterization/mesh/task-simple.slang5
13 files changed, 108 insertions, 35 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index d8b5c38d3..f8012af1d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2636,4 +2636,7 @@ __attributeTarget(FuncDecl)
attribute_syntax [DerivativeGroupQuad] : DerivativeGroupQuadAttribute;
__attributeTarget(FuncDecl)
-attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute; \ No newline at end of file
+attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute;
+
+__attributeTarget(FuncDecl)
+attribute_syntax [noRefInline] : NoRefInlineAttribute; \ No newline at end of file
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 77a224b61..919796943 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -11580,8 +11580,11 @@ void SetMeshOutputCounts(uint vertexCount, uint primitiveCount)
//
// This function doesn't return.
//
+// This function cannot be inlined due to a legalization pass happening mid-way through processing
+// and later more processing happening to the function which requires eventual inlining.
[KnownBuiltin("DispatchMesh")]
[require(glsl_hlsl_spirv, meshshading)]
+[noRefInline]
void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint threadGroupCountZ, __ref P meshPayload)
{
__target_switch
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 1126e84ef..9a27b3c06 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1447,6 +1447,13 @@ class NoInlineAttribute : public Attribute
SLANG_AST_CLASS(NoInlineAttribute)
};
+ /// A `[noRefInline]` attribute represents a request to not force inline a
+ /// function specifically due to a refType parameter.
+class NoRefInlineAttribute : public Attribute
+{
+ SLANG_AST_CLASS(NoRefInlineAttribute)
+};
+
class DerivativeGroupQuadAttribute : public Attribute
{
SLANG_AST_CLASS(DerivativeGroupQuadAttribute)
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 5cfcde0c9..9369afbc5 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -492,7 +492,7 @@ Result linkAndOptimizeIR(
{
// We could fail because
// 1) It's not inlinable for some reason (for example if it's recursive)
- SLANG_RETURN_ON_FAIL(performStringInlining(irModule, sink));
+ SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink));
}
lowerReinterpret(targetProgram, irModule, sink);
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 54af7a746..7a970589e 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -2725,19 +2725,18 @@ void legalizeEntryPointParameterForGLSL(
codeGenContext,
builder, paramType, paramLayout, LayoutResourceKind::VaryingInput, stage, pp);
- // Next we need to replace uses of the parameter with
- // references to the variable(s). We are going to do that
- // somewhat naively, by simply materializing the
- // variables at the start.
+ // we have a simple struct which represents all materialized GlobalParams, this
+ // struct will replace the no longer needed global variable which proxied as a
+ // GlobalParam.
IRInst* materialized = materializeValue(builder, globalValue);
+ // We next need to replace all uses of the proxy variable with the actual GlobalParam
pp->replaceUsesWith(materialized);
- // We finally need to replace all global variable references of a global
- // parameter with the actual global parameter for all function calls.
- // Global parameters are used with a OpStore to copy its data into a global
- // variable intermediary. We will follow the uses of a global parameter until
- // we find this OpStore, then we will replace uses of the intermediary object.
+ // GlobalParams use use a OpStore to copy its data into a global
+ // variable intermediary. We will follow the uses of this intermediary
+ // and replace all some of the uses (function calls and SPIRV Operands)
+ Dictionary<IRBlock*, IRInst*> blockToMaterialized;
IRBuilder replaceBuilder(materialized);
for (auto dec : pp->getDecorations())
{
@@ -2747,27 +2746,48 @@ void legalizeEntryPointParameterForGLSL(
auto globalVarType = cast<IRPtrTypeBase>(globalVar->getDataType())->getValueType();
auto key = dec->getOperand(1);
- // we will be replacing uses of `globalVarToReplace`, we need globalVarToReplaceNextUse
- // to catch the next use before it is removed from the list of uses
+ // we will be replacing uses of `globalVarToReplace`. We need globalVarToReplaceNextUse
+ // to catch the next use before it is removed from the list of uses.
IRUse* globalVarToReplaceNextUse;
for (auto globalVarUse = globalVar->firstUse; globalVarUse; globalVarUse = globalVarToReplaceNextUse)
{
globalVarToReplaceNextUse = globalVarUse->nextUse;
auto user = globalVarUse->getUser();
- if (user->getOp() != kIROp_Call)
- continue;
- for (Slang::UInt operandIndex = 0; operandIndex < user->getOperandCount();
- operandIndex++)
+ switch (user->getOp())
{
- auto operand = user->getOperand(operandIndex);
- auto operandUse = user->getOperands() + operandIndex;
- if (operand != globalVar)
- continue;
- replaceBuilder.setInsertBefore(user);
- auto field = replaceBuilder.emitFieldExtract(globalVarType, materialized, key);
- replaceBuilder.replaceOperand(operandUse, field);
+ case kIROp_SPIRVAsmOperandInst:
+ case kIROp_Call:
+ {
+ for (Slang::UInt operandIndex = 0; operandIndex < user->getOperandCount();
+ operandIndex++)
+ {
+ auto operand = user->getOperand(operandIndex);
+ auto operandUse = user->getOperands() + operandIndex;
+ if (operand != globalVar)
+ continue;
+
+ // a GlobalParam may be used across functions/blocks, we need to
+ // materialize at a minimum 1 struct per block.
+ auto callingBlock = getBlock(user);
+ bool found = blockToMaterialized.tryGetValue(callingBlock, materialized);
+ if (!found)
+ {
+ replaceBuilder.setInsertBefore(callingBlock->getFirstInst());
+ materialized = materializeValue(&replaceBuilder, globalValue);
+ blockToMaterialized.set(callingBlock, materialized);
+ }
+
+ replaceBuilder.setInsertBefore(user);
+ auto field = replaceBuilder.emitFieldExtract(globalVarType, materialized, key);
+ replaceBuilder.replaceOperand(operandUse, field);
+ break;
+ }
+ break;
+ }
+ default:
break;
}
+ continue;
}
}
}
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index a5538425f..1e8c1462f 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -709,15 +709,15 @@ void performMandatoryEarlyInlining(IRModule* module)
namespace { // anonymous
// Inlines calls that involve String types
-struct StringInliningPass : InliningPassBase
+struct TypeInliningPass : InliningPassBase
{
typedef InliningPassBase Super;
- StringInliningPass(IRModule* module)
+ TypeInliningPass(IRModule* module)
: Super(module)
{}
- bool doesTypeRequireInline(IRType* type)
+ bool doesTypeRequireInline(IRType* type, IRFunc* callee)
{
// TODO(JS):
// I guess there is a question here about what type around string requires
@@ -727,6 +727,12 @@ struct StringInliningPass : InliningPassBase
const auto op = type->getOp();
switch (op)
{
+ case kIROp_RefType:
+ {
+ if(callee->findDecoration<IRNoRefInlineDecoration>())
+ return false;
+ return true;
+ }
case kIROp_StringType:
case kIROp_NativeStringType:
{
@@ -742,7 +748,7 @@ struct StringInliningPass : InliningPassBase
{
auto callee = info.callee;
- if (doesTypeRequireInline(callee->getResultType()))
+ if (doesTypeRequireInline(callee->getResultType(), callee))
{
return true;
}
@@ -750,7 +756,7 @@ struct StringInliningPass : InliningPassBase
const auto count = Count(callee->getParamCount());
for (Index i = 0; i < count; ++i)
{
- if (doesTypeRequireInline(callee->getParamType(UInt(i))))
+ if (doesTypeRequireInline(callee->getParamType(UInt(i)), callee))
{
return true;
}
@@ -762,7 +768,7 @@ struct StringInliningPass : InliningPassBase
} // anonymous
-Result performStringInlining(IRModule* module, DiagnosticSink* sink)
+Result performTypeInlining(IRModule* module, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
@@ -780,7 +786,7 @@ Result performStringInlining(IRModule* module, DiagnosticSink* sink)
//
while(true)
{
- StringInliningPass pass(module);
+ TypeInliningPass pass(module);
if (pass.considerAllCallSites())
{
// If there was a change try inlining again
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index a15887b00..c5b457a65 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -10,8 +10,8 @@ namespace Slang
struct IRGlobalValueWithCode;
class DiagnosticSink;
- /// Any call to a function that takes or returns a string parameter is inlined
- Result performStringInlining(IRModule* module, DiagnosticSink* sink);
+ /// Any call to a function that takes or returns a string/RefType parameter is inlined
+ Result performTypeInlining(IRModule* module, DiagnosticSink* sink);
/// Inline any call sites to functions marked `[unsafeForceInlineEarly]`
void performMandatoryEarlyInlining(IRModule* module);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 2aad3bf8e..f4954375d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -862,6 +862,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
/// Applie to an IR function and signals that inlining should not be performed unless unavoidable.
INST(NoInlineDecoration, noInline, 0, 0)
+ INST(NoRefInlineDecoration, noRefInline, 0, 0)
INST(DerivativeGroupQuadDecoration, DerivativeGroupQuad, 0, 0)
INST(DerivativeGroupLinearDecoration, DerivativeGroupLinear, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 0d66efb14..eae025c96 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -393,6 +393,7 @@ IR_SIMPLE_DECORATION(HLSLExportDecoration)
IR_SIMPLE_DECORATION(KeepAliveDecoration)
IR_SIMPLE_DECORATION(RequiresNVAPIDecoration)
IR_SIMPLE_DECORATION(NoInlineDecoration)
+IR_SIMPLE_DECORATION(NoRefInlineDecoration)
IR_SIMPLE_DECORATION(DerivativeGroupQuadDecoration)
IR_SIMPLE_DECORATION(DerivativeGroupLinearDecoration)
IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e4fc33e33..a78110a84 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9738,6 +9738,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
derivativeGroupLinearDecor = getBuilder()->addSimpleDecoration<IRDerivativeGroupLinearDecoration>(irFunc);
}
+ else if (auto noRefInlineAttribute = as<NoRefInlineAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRNoRefInlineDecoration>(irFunc);
+ }
else if (auto instanceAttr = as<InstanceAttribute>(modifier))
{
IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr);
diff --git a/tests/bugs/gh-3997.slang b/tests/bugs/gh-3997.slang
new file mode 100644
index 000000000..8c75da426
--- /dev/null
+++ b/tests/bugs/gh-3997.slang
@@ -0,0 +1,23 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -O0 -g
+
+//CHECK: OpEntryPoint
+
+float atomicAdd(__ref float value, float amount)
+{
+ __target_switch
+ {
+ case cpp:
+ __requirePrelude("#include <atomic>");
+ __intrinsic_asm "std::atomic_ref(*$0).fetch_add($1)";
+ case spirv:
+ return __atomicAdd(value, amount);
+ }
+}
+
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(4, 1, 1)]
+[shader("compute")]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) {
+ atomicAdd(outputBuffer[0], 1);
+} \ No newline at end of file
diff --git a/tests/language-feature/non-copyable-return.slang b/tests/language-feature/non-copyable-return.slang
index 20330c5f9..9b280b982 100644
--- a/tests/language-feature/non-copyable-return.slang
+++ b/tests/language-feature/non-copyable-return.slang
@@ -31,7 +31,7 @@ void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
{
let f = myFunc0(2.0);
// CHECK: 4.0
- // GLSL: void myFunc1_0(float y{{.*}}, spirv_by_reference MyType_0 {{.*}})
- // GLSL: void myFunc0_0(float x{{.*}}, spirv_by_reference MyType_0 {{.*}})
+ // GLSL: main(
+ // GLSL-NOT: MyType {{.*}} =
outputBuffer[0] = f.x;
}
diff --git a/tests/pipeline/rasterization/mesh/task-simple.slang b/tests/pipeline/rasterization/mesh/task-simple.slang
index dc3de82c0..2b2f3d186 100644
--- a/tests/pipeline/rasterization/mesh/task-simple.slang
+++ b/tests/pipeline/rasterization/mesh/task-simple.slang
@@ -1,6 +1,11 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -task -output-using-type -dx12 -use-dxil -profile sm_6_6 -render-features mesh-shader
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -task -output-using-type -vk -profile glsl_450+spirv_1_4 -render-features mesh-shader
//TEST:SIMPLE(filecheck=HLSL):-target hlsl -entry meshMain -stage mesh
+//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry taskMain -stage amplification
+
+// CHECK_SPV: OpEntryPoint
+// CHECK_SPV: TaskPayloadWorkgroupEXT
+
// To test a simple mesh shader, we'll generate 4 triangles, the vertices of
// each one will hold the triangle index and a value (the square). The fragment