summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2020-07-10 14:30:57 -0700
committerGitHub <noreply@github.com>2020-07-10 14:30:57 -0700
commit249f48dbb5e240c713661be969a6939ec57561e5 (patch)
tree6e7c3f117931e640b89a63c0083c1f14e5b81ea4 /source
parent6aad38a43394a60c02c6109199d427d88147e781 (diff)
CUDA/CPU varying compute inputs as IR pass (#1438)
The main change here is that the CPU and CUDA C++ emit paths now rely on an earlier IR pass to legalize the varying parameter list of a kernel and translate references to varying parameters with semantics like `SV_DispatchThreadID`. Doing so removes a lot of special-case logic from the emit passes. This work moves us even closer to being able to eliminate `KernelContext` from the CPU/CUDA emit logic, because it removes the issue of state related to varying inputs being stored in `KernelContext`. The new pass that handles the legalization is in `slang-ir-legalize-varying-params.cpp`, and it borrows heavily from the existing `slang-ir-glsl-legalize.cpp` pass. The new pass factors out the target-independent and target-dependent logic, so that both CPU and CUDA can share much of the same code despite having very different rules for how the system-value parameters are being provided. An eventual goal is to have the new pass also handle the GLSL case, but doing so requires copying even more logic out of the GLSL-specific pass, and doing so seemed like a step to far for what was meant to be a stepping-stone change as part of other work. As a result of the incomplete nature of the pass, certain cases don't work for compute shader inputs for CPU/CUDA (e.g., wrapping your varying inputs in a `struct` type parameter), but those were cases that also didn't work in the existing `emit`-based logic. One major consequence of this change is that the logic for emitting the various different functions that represent an entry point for our CPU back-end has been streamlined and simplified. The original logic had a fair bit of cleverness built in to try and avoid unnecessary math ops when computing the various IDs/indices, while the new logic is much more simplistic (the main dispatch function loops over threadgroups with a triply-nested `for` and then delegates to the group-level function loops over threads with its own nested `for`s). Longer term, it will be important to simplify the CPU functions we emit further, by eliminating things like the `_Thread` function that should never really be exposed to users (the minimum granularity of invoking a CPU compute kernel should be a single threadgroup). We may eventually decide to synthesize all of the extra code that is being generated in the `emit` pass as IR instead.
Diffstat (limited to 'source')
-rw-r--r--source/core/core.vcxproj.filters6
-rw-r--r--source/slang/slang-emit-cpp.cpp241
-rw-r--r--source/slang/slang-emit-cuda.cpp50
-rw-r--r--source/slang/slang-emit-cuda.h1
-rw-r--r--source/slang/slang-emit.cpp14
-rw-r--r--source/slang/slang-ir-insts.h27
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp1321
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h20
-rw-r--r--source/slang/slang-ir.cpp23
-rw-r--r--source/slang/slang.vcxproj4
-rw-r--r--source/slang/slang.vcxproj.filters6
11 files changed, 1501 insertions, 212 deletions
diff --git a/source/core/core.vcxproj.filters b/source/core/core.vcxproj.filters
index 3cb5ec8ec..4331745ba 100644
--- a/source/core/core.vcxproj.filters
+++ b/source/core/core.vcxproj.filters
@@ -93,6 +93,9 @@
<ClInclude Include="slang-shared-library.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="slang-short-list.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="slang-smart-pointer.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -138,9 +141,6 @@
<ClInclude Include="windows\slang-win-visual-studio-util.h">
<Filter>Header Files</Filter>
</ClInclude>
- <ClInclude Include="slang-short-list.h">
- <Filter>Header Files</Filter>
- </ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="slang-blob.cpp">
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 54c2257f2..b59611b38 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1955,38 +1955,48 @@ void CPPSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
// Deal with decorations that need
// to be emitted as attributes
- // We are going to ignore the parameters passed and just pass in the Context
+ // We start by emitting the result type and function name.
+ //
if (IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>())
{
+ // Note: we currently emit multiple functions to represent an entry point
+ // on CPU/CUDA, and these all bottleneck through the actual `IRFunc`
+ // here as a workhorse.
+ //
+ // Because the workhorse function is currently emitted as a member of
+ // `KernelContext`, and doesn't have the right signature to service
+ // general-purpose calls, it is being emitted with a `_` prefix.
+ //
StringBuilder prefixName;
prefixName << "_" << name;
emitType(resultType, prefixName);
- m_writer->emit("()\n");
}
else
{
emitType(resultType, name);
+ }
- m_writer->emit("(");
- auto firstParam = func->getFirstParam();
- for (auto pp = firstParam; pp; pp = pp->getNextParam())
- {
- // Ingore TypeType-typed parameters for now.
- // In the future we will pass around runtime type info
- // for TypeType parameters.
- if (as<IRTypeType>(pp->getFullType()))
- continue;
-
- if (pp != firstParam)
- m_writer->emit(", ");
+ // Next we emit the parameter list of the function.
+ //
+ m_writer->emit("(");
+ auto firstParam = func->getFirstParam();
+ for (auto pp = firstParam; pp; pp = pp->getNextParam())
+ {
+ // Ingore TypeType-typed parameters for now.
+ // In the future we will pass around runtime type info
+ // for TypeType parameters.
+ if (as<IRTypeType>(pp->getFullType()))
+ continue;
- emitSimpleFuncParamImpl(pp);
- }
- m_writer->emit(")");
+ if (pp != firstParam)
+ m_writer->emit(", ");
- emitSemantics(func);
+ emitSimpleFuncParamImpl(pp);
}
+ m_writer->emit(")");
+
+ emitSemantics(func);
// TODO: encode declaration vs. definition
if (isDefinition(func))
@@ -2431,40 +2441,6 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre
switch (inst->op)
{
- case kIROp_Param:
- {
- auto varLayout = getVarLayout(inst);
-
- if (varLayout)
- {
- if(auto systemValueSemantic = varLayout->findSystemValueSemanticAttr())
- {
- String semanticNameSpelling = systemValueSemantic->getName();
- semanticNameSpelling = semanticNameSpelling.toLower();
-
- if (semanticNameSpelling == "sv_dispatchthreadid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
- m_writer->emit("dispatchThreadID");
- return;
- }
- else if (semanticNameSpelling == "sv_groupid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
- m_writer->emit("groupID");
- return;
- }
- else if (semanticNameSpelling == "sv_groupthreadid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
- m_writer->emit("calcGroupThreadID()");
- return;
- }
- }
- }
- m_writer->emit(getName(inst));
- break;
- }
case kIROp_Var:
case kIROp_GlobalVar:
emitVarExpr(inst, outerPrec);
@@ -2591,19 +2567,19 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
const auto& axis = axes[i];
builder.Clear();
const char elem[2] = { s_elemNames[axis.axis], 0 };
- builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < start." << elem << " + " << axis.size << "; ++" << elem << ")\n{\n";
+ builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
m_writer->indent();
builder.Clear();
- builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
+ builder << "threadInput.groupThreadID." << elem << " = " << elem << ";\n";
m_writer->emit(builder);
}
// just call at inner loop point
m_writer->emit("context._");
m_writer->emit(funcName);
- m_writer->emit("();\n");
+ m_writer->emit("(&threadInput);\n");
// Close all the loops
for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
@@ -2626,57 +2602,20 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
builder.Clear();
const char elem[2] = { s_elemNames[axis.axis], 0 };
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n";
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- builder << "context.groupID." << elem << " += varyingInput->startGroupID." << elem << ";\n";
- }
-
- builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n";
+ builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
m_writer->indent();
- builder.Clear();
- builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
-
- if (m_semanticUsedFlags & (SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::GroupID))
- {
- if (sizeAlongAxis[axis.axis] > 1)
- {
- builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<";\n";
-
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n";
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n";
- }
- }
- else
- {
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n";
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- builder << "context.groupID." << elem << " = " << elem << ";\n";
- }
- }
- }
-
- m_writer->emit(builder);
+ m_writer->emit("groupVaryingInput.startGroupID.");
+ m_writer->emit(elem);
+ m_writer->emit(" = ");
+ m_writer->emit(elem);
+ m_writer->emit(";\n");
}
// just call at inner loop point
- m_writer->emit("context._");
m_writer->emit(funcName);
- m_writer->emit("();\n");
+ m_writer->emit("_Group(&groupVaryingInput, entryPointParams, globalParams);\n");
// Close all the loops
for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
@@ -2736,6 +2675,34 @@ void CPPSourceEmitter::_emitForwardDeclarations(const List<EmitAction>& actions)
}
}
+static bool isVaryingResourceKind(LayoutResourceKind kind)
+{
+ switch(kind)
+ {
+ default:
+ return false;
+
+ case LayoutResourceKind::VaryingInput:
+ case LayoutResourceKind::VaryingOutput:
+ return true;
+ }
+}
+
+static bool isVaryingParameter(IRTypeLayout* typeLayout)
+{
+ for(auto sizeAttr : typeLayout->getSizeAttrs())
+ {
+ if(!isVaryingResourceKind(sizeAttr->getResourceKind()))
+ return false;
+ }
+ return true;
+}
+
+static bool isVaryingParameter(IRVarLayout* varLayout)
+{
+ return isVaryingParameter(varLayout->getTypeLayout());
+}
+
void CPPSourceEmitter::_findShaderParams(
IRGlobalParam** outEntryPointParam,
IRGlobalParam** outGlobalParam)
@@ -2752,6 +2719,20 @@ void CPPSourceEmitter::_findShaderParams(
if(!param)
continue;
+ if(auto layoutDecor = param->findDecoration<IRLayoutDecoration>())
+ {
+ if(auto varLayout = as<IRVarLayout>(layoutDecor->getLayout()))
+ {
+ if(isVaryingParameter(varLayout))
+ continue;
+ auto typeLayout = varLayout->getTypeLayout();
+ if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingOutput))
+ continue;
+ }
+ }
+
// Currently, the entry-point parameters
// are represented as a single parameter
// at the global scope, and the same is
@@ -2806,28 +2787,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
m_writer->emit("struct KernelContext\n{\n");
m_writer->indent();
- m_writer->emit("uint3 dispatchThreadID;\n");
-
- //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- // Note not always set!
- m_writer->emit("uint3 groupID;\n");
- }
-
- //if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- m_writer->emit("uint3 groupDispatchThreadID;\n");
-
- m_writer->emit("uint3 calcGroupThreadID() const \n{\n");
- m_writer->indent();
- // groupThreadID = dispatchThreadID - groupDispatchThreadID
- m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z };\n");
- m_writer->emit("return v;\n");
- m_writer->dedent();
- m_writer->emit("}\n");
- }
-
-
if (globalParams)
{
emitGlobalInst(globalParams);
@@ -2886,9 +2845,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
if (entryPointDecor && entryPointDecor->getProfile().getStage() == Stage::Compute)
{
- // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
- // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
-
Int groupThreadSize[kThreadGroupAxisCount];
getComputeThreadGroupSize(func, groupThreadSize);
@@ -2902,23 +2858,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, threadFuncName, UnownedStringSlice::fromLiteral("ComputeThreadVaryingInput"));
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- m_writer->emit("context.groupDispatchThreadID = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- m_writer->emit("context.groupID = varyingInput->groupID;\n");
- }
-
- // Emit dispatchThreadID
- m_writer->emit("context.dispatchThreadID = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
-
m_writer->emit("context._");
m_writer->emit(funcName);
- m_writer->emit("();\n");
+ m_writer->emit("(varyingInput);\n");
_emitEntryPointDefinitionEnd(func);
}
@@ -2933,19 +2875,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
-
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- m_writer->emit("context.groupDispatchThreadID = start;\n");
- }
-
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- m_writer->emit("context.groupID = varyingInput->startGroupID;\n");
- }
- m_writer->emit("context.dispatchThreadID = start;\n");
+ m_writer->emit("ComputeThreadVaryingInput threadInput = {};\n");
+ m_writer->emit("threadInput.groupID = varyingInput->startGroupID;\n");
_emitEntryPointGroup(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
@@ -2955,10 +2886,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
{
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
- m_writer->emit("const uint3 end = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
+ m_writer->emit("ComputeVaryingInput vi = *varyingInput;\n");
+ m_writer->emit("ComputeVaryingInput groupVaryingInput = {};\n");
_emitEntryPointGroupRange(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index d05c4edfc..c7dee9f9d 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -248,56 +248,6 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
SLANG_UNUSED(entryPointDecor);
}
-void CUDASourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec)
-{
- if (shouldFoldInstIntoUseSites(inst))
- {
- emitInstExpr(inst, outerPrec);
- return;
- }
-
- switch (inst->op)
- {
- case kIROp_Param:
- {
- auto varLayout = getVarLayout(inst);
- if (varLayout)
- {
- if (auto systemValueSemantic = varLayout->findSystemValueSemanticAttr())
- {
- String semanticNameSpelling = systemValueSemantic->getName();
- semanticNameSpelling = semanticNameSpelling.toLower();
-
- if (semanticNameSpelling == "sv_dispatchthreadid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
- m_writer->emit("((blockIdx * blockDim) + threadIdx)");
-
- return;
- }
- else if (semanticNameSpelling == "sv_groupid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
- m_writer->emit("blockIdx");
- return;
- }
- else if (semanticNameSpelling == "sv_groupthreadid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
- m_writer->emit("threadIdx");
- return;
- }
- }
- }
-
- break;
- }
- default: break;
- }
-
- Super::emitOperandImpl(inst, outerPrec);
-}
-
void CUDASourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec)
{
switch (specOp->op)
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index c0a5ac5dc..9afd34c4b 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -55,7 +55,6 @@ protected:
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
- virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec) SLANG_OVERRIDE;
virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE { SLANG_UNUSED(inst); m_writer->emit("__device__ "); }
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 7d2260597..d35acf4df 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -11,6 +11,7 @@
#include "slang-ir-entry-point-uniforms.h"
#include "slang-ir-glsl-legalize.h"
#include "slang-ir-insts.h"
+#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-link.h"
#include "slang-ir-lower-generics.h"
#include "slang-ir-restructure.h"
@@ -590,6 +591,19 @@ Result linkAndOptimizeIR(
}
break;
+ case CodeGenTarget::CSource:
+ case CodeGenTarget::CPPSource:
+ {
+ legalizeEntryPointVaryingParamsForCPU(irModule, compileRequest->getSink());
+ }
+ break;
+
+ case CodeGenTarget::CUDASource:
+ {
+ legalizeEntryPointVaryingParamsForCUDA(irModule, compileRequest->getSink());
+ }
+ break;
+
default:
break;
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 553992406..d07a6d76e 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -292,6 +292,8 @@ struct IRNumThreadsDecoration : IRDecoration
IRIntLit* getX() { return cast<IRIntLit>(getOperand(0)); }
IRIntLit* getY() { return cast<IRIntLit>(getOperand(1)); }
IRIntLit* getZ() { return cast<IRIntLit>(getOperand(2)); }
+
+ IRIntLit* getExtentAlongAxis(int axis) { return cast<IRIntLit>(getOperand(axis)); }
};
struct IREntryPointDecoration : IRDecoration
@@ -1557,6 +1559,19 @@ struct IRConstantKey
struct SharedIRBuilder
{
+ SharedIRBuilder()
+ {}
+
+ SharedIRBuilder(Session* session, IRModule* module)
+ : session(session)
+ , module(module)
+ {}
+
+ explicit SharedIRBuilder(IRModule* module)
+ : session(module->getSession())
+ , module(module)
+ {}
+
// The parent compilation session
Session* session;
Session* getSession()
@@ -1577,8 +1592,15 @@ struct IRBuilderSourceLocRAII;
struct IRBuilder
{
+ IRBuilder()
+ {}
+
+ IRBuilder(SharedIRBuilder* sharedBuilder)
+ : sharedBuilder(sharedBuilder)
+ {}
+
// Shared state for all IR builders working on the same module
- SharedIRBuilder* sharedBuilder;
+ SharedIRBuilder* sharedBuilder = nullptr;
Session* getSession()
{
@@ -2112,6 +2134,9 @@ struct IRBuilder
IRInst* emitBitAnd(IRType* type, IRInst* left, IRInst* right);
IRInst* emitBitNot(IRType* type, IRInst* value);
+ IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right);
+ IRInst* emitMul(IRType* type, IRInst* left, IRInst* right);
+
//
// Decorations
//
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
new file mode 100644
index 000000000..5772e79f9
--- /dev/null
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -0,0 +1,1321 @@
+// slang-ir-legalize-varying-params.cpp
+#include "slang-ir-legalize-varying-params.h"
+
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+// This pass implements logic to "legalize" the varying parameter
+// signature of an entry point.
+//
+// The traditional Slang/HLSL model is to have varying input parameters
+// be marked with "semantics" that can either mark them as user-defined
+// or system-value parameters. In addition the result (return value)
+// of the function can be marked, and effectively works like an `out`
+// parameter.
+//
+// Other targets have very different models for how varying parameters
+// are passed:
+//
+// * GLSL/SPIR-V declare user-defined varying input/output as global variables,
+// and system-defined varying parameters are available as magic built-in variables.
+//
+// * CUDA compute kernels expose varying inputs as magic built-in
+// variables like `threadIdx`.
+//
+// * Our CPU compilation path requires the caller to pass in a `ComputeThreadVaryingInput`
+// that specifies the values of the critical varying parameters for compute shaders.
+//
+// While these targets differ in how they prefer to represent varying parameters,
+// they share the common theme that they cannot work with the varying parameter
+// signature of functions as written in vanilla HLSL.
+//
+// This pass in this file is responsible for walking the parameters (and result)
+// of each entry point in an IR module and transforming them into a form that
+// is legal for each target. The shared logic deals with many aspects of the
+// HLSL/Slang model for varying parameters that need to be "desugared" for these
+// targets:
+//
+// * Slang allows either an `out` parameter or the result (return value) of the
+// entry point to be used interchangeably, so ensuring both cases are treated
+// the same is handled here.
+//
+// * Slang allows a varying parameter to use a `struct` or array type, so that
+// we need to recursively process elements and/or fields to find the leaf
+// varying parameters as they will be understood by other targets.
+//
+// * As an extension of the above, `struct`-type varying parameters in Slang
+// may mix user-defined and system-defined inputs/outputs.
+//
+// * Slang allows for `inout` varying parameters, which need to desugar into
+// distinct `in` and `out` parameters for targets like GLSL.
+
+
+#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \
+ M(DispatchThreadID, SV_DispatchThreadID) \
+ M(GroupID, SV_GroupID) \
+ M(GroupThreadID, SV_GroupThreadID) \
+ M(GroupThreadIndex, SV_GroupIndex) \
+ /* end */
+
+ /// A known system-value semantic name that can be applied to a parameter
+ ///
+enum class SystemValueSemanticName
+{
+ None = 0,
+
+ // TODO: Should this enumeration be responsible for differentiating
+ // cases where the same semantic name string is allowed in multiple stages,
+ // or as both input/output in a single stage, and those different uses
+ // might result in different meanings? The alternative is to always
+ // pass around the semantic name, stage, and direction together so
+ // that code can tell those special cases apart.
+
+#define CASE(ID, NAME) ID,
+SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
+#undef CASE
+
+ // TODO: There are many more system-value semantic names that we
+ // can/should handle here, but for now I've restricted this list
+ // to those that are necessary for translating compute shaders.
+};
+
+ /// A placeholder that represents the value of a legalized varying
+ /// parameter, for the purposes of substituting it into IR code.
+ ///
+struct LegalizedVaryingVal
+{
+public:
+ enum class Flavor
+ {
+ None, ///< No value (conceptually a literal of type `void`)
+
+ Value, ///< A simple value represented as a single `IRInst*`
+
+ Address, ///< A location in memory, identified by an address in an `IRInst*`
+ };
+
+ LegalizedVaryingVal()
+ {}
+
+ static LegalizedVaryingVal makeValue(IRInst* irInst)
+ {
+ return LegalizedVaryingVal(Flavor::Value, irInst);
+ }
+
+ static LegalizedVaryingVal makeAddress(IRInst* irInst)
+ {
+ return LegalizedVaryingVal(Flavor::Address, irInst);
+ }
+
+ Flavor getFlavor() const { return m_flavor; }
+
+ IRInst* getValue() const
+ {
+ SLANG_ASSERT(getFlavor() == Flavor::Value);
+ return m_irInst;
+ }
+
+ IRInst* getAddress() const
+ {
+ SLANG_ASSERT(getFlavor() == Flavor::Address);
+ return m_irInst;
+ }
+
+private:
+ LegalizedVaryingVal(Flavor flavor, IRInst* irInst)
+ : m_flavor(flavor)
+ , m_irInst(irInst)
+ {}
+
+ Flavor m_flavor = Flavor::None;
+ IRInst* m_irInst = nullptr;
+};
+
+ /// Materialize the value of `val` as a single IR instruction.
+ ///
+ /// Any IR code that is needed to materialize the value will be emitted to `builder`.
+IRInst* materialize(IRBuilder& builder, LegalizedVaryingVal const& val)
+{
+ switch( val.getFlavor() )
+ {
+ case LegalizedVaryingVal::Flavor::None:
+ return nullptr; // TODO: should use a `void` literal
+
+ case LegalizedVaryingVal::Flavor::Value:
+ return val.getValue();
+
+ case LegalizedVaryingVal::Flavor::Address:
+ return builder.emitLoad(val.getAddress());
+
+ default:
+ SLANG_UNEXPECTED("unimplemented");
+ break;
+ }
+}
+
+void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, LegalizedVaryingVal const& src)
+{
+ switch( dest.getFlavor() )
+ {
+ case LegalizedVaryingVal::Flavor::None:
+ break;
+
+ case LegalizedVaryingVal::Flavor::Address:
+ builder.emitStore(dest.getAddress(), materialize(builder, src));
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unimplemented");
+ break;
+ }
+}
+
+void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src)
+{
+ assign(builder, dest, LegalizedVaryingVal::makeValue(src));
+}
+
+ /// Context for the IR pass that legalizing entry-point
+ /// varying parameters for a target.
+ ///
+ /// This is an abstract base type that needs to be inherited
+ /// to implement the appropriate policy for a particular
+ /// compilation target.
+ ///
+struct EntryPointVaryingParamLegalizeContext
+{
+ // This pass will be invoked on an entire module, and will
+ // process all entry points in that module.
+ //
+public:
+ void processModule(IRModule* module, DiagnosticSink* sink)
+ {
+ m_module = module;
+ m_sink = sink;
+
+ // We will use multiple IR builders during the legalization
+ // process, to avoid having state changes on one builder
+ // affect other builders that might be in use.
+ //
+ // All of those builders will need to have a common
+ // shared builder to avoid unnecessary duplication of
+ // types/constants.
+ //
+ SharedIRBuilder sharedBuilderStorage;
+ sharedBuilderStorage.module = module;
+ sharedBuilderStorage.session = module->getSession();
+ m_sharedBuilder = &sharedBuilderStorage;
+
+ // Once the basic initialization is done, we will allow
+ // the subtype to implement its own initialization logic
+ // that should occur at the start of processing a module.
+ //
+ beginModuleImpl();
+
+ // We now search for entry-point definitions in the IR module.
+ // All entry points should appear at the global scope.
+ //
+ for(auto inst : module->getGlobalInsts())
+ {
+ // Entry points are IR functions.
+ //
+ auto func = as<IRFunc>(inst);
+ if(!func)
+ continue;
+
+ // Entry point functions must have the `[entryPoint]` decoration.
+ //
+ auto entryPointDecor = func->findDecoration<IREntryPointDecoration>();
+ if(!entryPointDecor)
+ continue;
+
+ // Once we find an entry point we process it immediately.
+ //
+ processEntryPoint(func, entryPointDecor);
+ }
+ }
+
+protected:
+
+ // As discussed in `processModule()`, a subtype can overide
+ // the `beginModuleImpl()` method to perform work that should
+ // only happen once per module that is processed.
+ //
+ virtual void beginModuleImpl()
+ {}
+
+ // We have both per-module and per-entry-point state that
+ // needs to be managed. The former is set up in `processModule()`,
+ // while the latter is used during `processEntryPoint`.
+ //
+ // Note: It would be possible in principle to remove some
+ // the statefullness from this pass by factoring the
+ // per-module and per-entry-point logic into distinct types,
+ // but then every target-specific implementation would
+ // need to comprise two types with complicated interdependencies.
+ // The current solution of a single type with statefullness
+ // seems easier to manage.
+
+ IRModule* m_module = nullptr;
+ DiagnosticSink* m_sink = nullptr;
+ SharedIRBuilder* m_sharedBuilder = nullptr;
+
+ IRFunc* m_entryPointFunc = nullptr;
+ IRBlock* m_firstBlock = nullptr;
+ IRInst* m_firstOrdinaryInst = nullptr;
+ Stage m_stage = Stage::Unknown;
+
+
+ void processEntryPoint(IRFunc* entryPointFunc, IREntryPointDecoration* entryPointDecor)
+ {
+ m_entryPointFunc = entryPointFunc;
+
+ // Before diving into the work of processing an entry point, we start by
+ // extracting a bunch of information about the entry point that will
+ // be useful to the downstream logic.
+ //
+ m_stage = entryPointDecor->getProfile().getStage();
+ m_firstBlock = entryPointFunc->getFirstBlock();
+ m_firstOrdinaryInst = m_firstBlock ? m_firstBlock->getFirstOrdinaryInst() : nullptr;
+
+ auto entryPointLayoutDecoration = entryPointFunc->findDecoration<IRLayoutDecoration>();
+ SLANG_ASSERT(entryPointLayoutDecoration);
+
+ auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecoration->getLayout());
+ SLANG_ASSERT(entryPointLayout);
+
+ // Note: Of particular importance is that we extract the first/last parameters
+ // of the function *before* we allow the subtype to perform per-entry-point
+ // setup operations. This ensures that if the subtype adds new parameters to
+ // the beginnign or end of the parameter list, those new parameters won't
+ // be processed.
+ //
+ IRParam* firstOriginalParam = m_firstBlock ? m_firstBlock->getFirstParam() : nullptr;
+ IRParam* lastOriginalParam = m_firstBlock ? m_firstBlock->getLastParam() : nullptr;
+
+ // We allow the subtype to perform whatever setup or code generation
+ // it wants to on a per-entry-point basis. In some cases this might
+ // inject code into the start of the function to provide the value
+ // of certain system-value parameters.
+ //
+ beginEntryPointImpl();
+
+ // We now proceed to the meat of the work.
+ //
+ // We start by considering the result of the entry point function
+ // if it is non-`void`.
+ //
+ auto resultType = entryPointFunc->getResultType();
+ if( !as<IRVoidType>(resultType) )
+ {
+ // We need to translate the existing function result type
+ // into zero or more varying parameters that are legal for
+ // the target. An entry point function result should be
+ // processed in a way that semantically matches an `out` parameter.
+ //
+ auto legalResult = createLegalVaryingVal(
+ resultType,
+ entryPointLayout->getResultLayout(),
+ LayoutResourceKind::VaryingOutput);
+
+ // Now that we have a representation of the value(s) that will
+ // be used to hold the entry-point result we need to transform
+ // any `returnVal(r)` instructions in the function body to
+ // instead assign `r` to `legalResult` and then `returnVoid`.
+ //
+ IRBuilder builder(m_sharedBuilder);
+ for( auto block : entryPointFunc->getBlocks() )
+ {
+ auto returnValInst = as<IRReturnVal>(block->getTerminator());
+ if(!returnValInst)
+ continue;
+
+ // We have a `returnVal` instruction that returns `resultVal`.
+ //
+ auto resultVal = returnValInst->getVal();
+
+ // To replace the existing `returnVal` instruction we will
+ // emit an assignment to the new legalized result (whether
+ // a global variable, `out` parameter, etc.) and a `returnVoid`.
+ //
+ builder.setInsertBefore(returnValInst);
+ assign(builder, legalResult, resultVal);
+ builder.emitReturn();
+
+ returnValInst->removeAndDeallocate();
+ }
+ }
+
+ // The parameters of the entry-point function will be processed in
+ // order to legalize them. We need to be careful when iterating
+ // over the parameters for a few reasons:
+ //
+ // * The subtype-specific setup logic could have introduce parameters
+ // at the beginning or end of the list. We defend against that by
+ // capturing `firstOriginalParam` and `lastOriginalParam` at the
+ // start of this function, and only iterating over that range.
+ //
+ // * Somehow we might have an entry point declaration but not a definition
+ // this is unlikely but defended against because `firstOriginalParam`
+ // and `lastOriginalParam` will be null in that case.
+ //
+ // * We will often be removing the parameters once we have legalized
+ // them, so we will modify the list while traversing it. We defend
+ // against this by capturing `nextParam` at the start of each iteration
+ // so that we move to the same parameter next, even if the current
+ // parameter got removed.
+ //
+ // * The subtype-specific logic for legalizing a specific parameter
+ // might decide to insert new parameters to replace it. This is another
+ // case of modifying the parameter list while iterating it, and we
+ // defend against it with `nextParam` just like we do for the problem
+ // of deletion.
+ //
+ IRParam* nextParam = nullptr;
+ for( auto param = firstOriginalParam; param; param = nextParam )
+ {
+ nextParam = param->getNextParam();
+
+ processParam(param);
+
+ if(param == lastOriginalParam)
+ break;
+ }
+ }
+
+ virtual void beginEntryPointImpl() {}
+
+ // The next level down is the per-parameter processing logic, which
+ // like the per-module and per-entry-point levels maintains its own
+ // state to simplify the code (avoiding lots of long parameters lists).
+
+ IRParam* m_param = nullptr;
+ IRVarLayout* m_paramLayout = nullptr;
+
+ void processParam(IRParam* param)
+ {
+ m_param = param;
+
+ // We expect and require all entry-point parameters to have layout
+ // information assocaited with them at this point.
+ //
+ auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>();
+ SLANG_ASSERT(paramLayoutDecoration);
+ m_paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout());
+ SLANG_ASSERT(m_paramLayout);
+
+ // TODO: We need to detect and skip parameters here that are not varying.
+
+ // TODO: The GLSL-specific variant of this pass has several
+ // special cases that handle entry-point parameters for things like
+ // GS output streams and input primitive topology.
+
+ // TODO: The GLSL-specific variant of this pass has special cases
+ // to deal with user-defined varying input to RT shaders, since
+ // these don't translate to globals in the same way as all other
+ // GLSL varying inputs.
+
+ // We need to start by detecting whether the parameter represents
+ // an `in` or an `out`/`inout` parameter, since that will determine
+ // the strategy we take.
+ //
+ auto paramType = param->getDataType();
+ if(auto inOutType = as<IRInOutType>(paramType))
+ {
+ processInOutParam(param, inOutType);
+ }
+ else if(auto outType = as<IROutType>(paramType))
+ {
+ processOutParam(param, outType);
+ }
+ else
+ {
+ processInParam(param, paramType);
+ }
+ }
+
+ // We anticipate that some targets may need to customize the handling
+ // of `out` and `inout` varying parameters, so we have `virtual` methods
+ // to handle those cases, which just delegate to a default implementation
+ // that provides baseline behavior that should in theory work for
+ // multiple targets.
+ //
+ virtual void processInOutParam(IRParam* param, IRInOutType* inOutType)
+ {
+ processMutableParam(param, inOutType);
+ }
+ virtual void processOutParam(IRParam* param, IROutType* inOutType)
+ {
+ processMutableParam(param, inOutType);
+ }
+
+ void processMutableParam(IRParam* param, IROutTypeBase* paramPtrType)
+ {
+ // The deafult handling of any mutable (`out` or `inout`) parameter
+ // will be to introduce a local variable of the corresponding
+ // type and to use that in place of the actual parameter during
+ // exeuction of the function.
+
+ // The replacement variable will have the type of the original
+ // parameter (the `T` in `Out<T>` or `InOut<T>`).
+ //
+ auto valueType = paramPtrType->getValueType();
+
+ // The replacement variable will be declared at the top of
+ // the function.
+ //
+ IRBuilder builder(m_sharedBuilder);
+ builder.setInsertBefore(m_firstOrdinaryInst);
+
+ auto localVar = builder.emitVar(valueType);
+ auto localVal = LegalizedVaryingVal::makeAddress(localVar);
+
+ if( auto inOutType = as<IRInOutType>(paramPtrType) )
+ {
+ // If the parameter was an `inout` and not just an `out`
+ // parameter, we will create one more more legal `in`
+ // parameters to represent the incoming value,
+ // and then assign from those legalized input(s)
+ // into our local variable at the start of the function.
+ //
+ auto inputVal = createLegalVaryingVal(
+ valueType,
+ m_paramLayout,
+ LayoutResourceKind::VaryingInput);
+ assign(builder, localVal, inputVal);
+ }
+
+ // Because the `out` or `inout` parameter is represented
+ // as a pointer, and our local variabel is also a pointer
+ // we can directly replace all uses of the original parameter
+ // with uses of the variable.
+ //
+ param->replaceUsesWith(localVar);
+
+ // For both `out` and `inout` parameters, we need to
+ // introduce one or more legalized `out` parameters
+ // to represent the outgoing value.
+ //
+ auto outputVal = createLegalVaryingVal(
+ valueType,
+ m_paramLayout,
+ LayoutResourceKind::VaryingOutput);
+
+ // In order to have changes to our local variable become
+ // visible in the legalized outputs, we need to assign
+ // from the local variable to the output as the last
+ // operation before any `return` instructions.
+ //
+ for( auto block : m_entryPointFunc->getBlocks() )
+ {
+ auto returnInst = as<IRReturn>(block->getTerminator());
+ if(!returnInst)
+ continue;
+
+ builder.setInsertBefore(returnInst);
+ assign(builder, outputVal, localVal);
+ }
+
+ // Once we are done replacing the original parameter,
+ // we can remove it from the function.
+ //
+ param->removeAndDeallocate();
+ }
+
+ void processInParam(IRParam* param, IRType* paramType)
+ {
+ // Legalizing an `in` parameter is easier than a mutable parameter.
+
+ // We start by creating one or more legalized `in` parameters
+ // to represent the incoming value.
+ //
+ auto legalVal = createLegalVaryingVal(
+ paramType,
+ m_paramLayout,
+ LayoutResourceKind::VaryingInput);
+
+ // Next, we "materialize" the legalized value to produce
+ // an `IRInst*` that represents it.
+ //
+ // Note: We materialize each input parameter once, at the top
+ // of the entry point. Making a copy in this way could
+ // introduce overhead if an input parameter is an array,
+ // since all indexing operations will now refer to a copy
+ // of the original array.
+ //
+ // TODO: We could in theory iterate over all uses of
+ // `param` and introduce a custom replacement for each.
+ // Such a replacement strategy could produce better code
+ // for things like indexing into varying arrays, but at the
+ // cost of more accesses to the input parameter data.
+ //
+ IRBuilder builder(m_sharedBuilder);
+ builder.setInsertBefore(m_firstOrdinaryInst);
+ IRInst* materialized = materialize(builder, legalVal);
+
+ // The materialized value can be used to completely
+ // replace the original parameter.
+ //
+ param->replaceUsesWith(materialized);
+ param->removeAndDeallocate();
+ }
+
+ // Depending on the "direction" of the parameter (`in`, `out`, `inout`)
+ // we may need to create one or legalized variables to represented it.
+ //
+ // We now turn our attention to the problem of creating a legalized
+ // value (wrapping zero or more variables/parameters) to represent
+ // a varying parameter of a given type for a specific direction:
+ // either input or output, but not both.
+ //
+ LegalizedVaryingVal createLegalVaryingVal(IRType* type, IRVarLayout* varLayout, LayoutResourceKind kind)
+ {
+ // The process we are going to use for creating legalized
+ // values is going to involve recursion over the `type`
+ // of the parameter, and there is a lot of state that
+ // we need to carry along the way.
+ //
+ // Rather than have our core recursive function have
+ // many parameters that need to be followed through
+ // all the recursive call sites, we are going to wrap
+ // the relevant data up in a `struct` and pass all
+ // the information down as a bundle.
+
+ auto typeLayout = varLayout->getTypeLayout();
+
+ VaryingParamInfo info;
+ info.type = type;
+ info.varLayout = varLayout;
+ info.typeLayout = typeLayout;
+ info.kind = kind;
+
+ return _createLegalVaryingVal(info);
+ }
+
+ // While recursing through the type of a varying parameter,
+ // we may need to make a recursive call on the element type
+ // of an array, while still tracking the fact that any
+ // leaf parameter we encounter needs to have the "outer
+ // array brackets" taken into account when giving it a type.
+ //
+ // For those purposes we have the `VaryingArrayDeclaratorInfo`
+ // type that keeps track of outer layers of array-ness
+ // for a parameter during our recursive walk.
+ //
+ // It is stored as a stack-allocated linked list, where the list flows
+ // up through the call stack.
+ //
+ struct VaryingArrayDeclaratorInfo
+ {
+ IRInst* elementCount = nullptr;
+ VaryingArrayDeclaratorInfo* next = nullptr;
+ };
+
+ // Here is the declaration of the bundled information we care
+ // about when declaring a varying parameter.
+ //
+ struct VaryingParamInfo
+ {
+ // We obviously care about the type of the parameter we
+ // need to legalize, as well as the layout of that type.
+ //
+ IRType* type = nullptr;
+ IRTypeLayout* typeLayout = nullptr;
+
+ // We also care about the variable layout information for
+ // the parameter, because that includes things like the semantic
+ // name/index, as well as any binding information that was
+ // computed (e.g., for the `location` of GLSL user-defined
+ // varying parameters).
+ //
+ // Note: the `varLayout` member may not represent a layout for
+ // a variable of the given `type`, because we might be peeling
+ // away layers of array-ness. Consider:
+ //
+ // int stuff[3] : STUFF
+ //
+ // When processing the parameter `stuff`, we start with `type`
+ // being `int[3]`, but then we will recurse on `int`. At that
+ // point the `varLayout` will still refer to `stuff` with its
+ // semantic of `STUFF`, but the `type` and `typeLayout` will
+ // refer to the `int` type.
+ //
+ IRVarLayout* varLayout = nullptr;
+
+ // As discussed above, sometimes `varLayout` will refer to an
+ // outer declaration of array type, while `type` and `typeLayout`
+ // refer to an element type (perhaps nested).
+ //
+ // The `arrayDeclarators` field stores a linked list representing
+ // outer layers of "array brackets" that surround the variable/field
+ // of `type`.
+ //
+ // If code decides to construct a leaf parameter based on `type`,
+ // then it will need to use these `arrayDeclarators` to wrap the
+ // type up to make it correct.
+ //
+ VaryingArrayDeclaratorInfo* arrayDeclarators = nullptr;
+
+ // In some cases the decision-making about how to lower a parameter
+ // will depend on the kind of varying parameter (input or output).
+ //
+ // TODO: We may find that there are cases where a target wants to
+ // support true `inout` varying parameters, and `LayoutResourceKind`
+ // cannot currently handle those.
+ //
+ LayoutResourceKind kind = LayoutResourceKind::None;
+
+ // When we arrive at a leaf parameter/field, we can identify whether
+ // it is a user-defined or system-value varying based on its semantic name.
+ //
+ // For convenience, target-specific subtypes only need to understand
+ // the enumerated `systemValueSemanticName` rather than needing to
+ // implement their own parsing of semantic name strings.
+ //
+ SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None;
+ };
+
+ LegalizedVaryingVal _createLegalVaryingVal(VaryingParamInfo const& info)
+ {
+ // By default, when we seek to creating a legalized value
+ // for a varying parameter, we will look at its type to
+ // decide what to do.
+ //
+ // For most basic types, we will immediately delegate to the
+ // base case (which will use target-specific logic).
+ //
+ // Note: The logic here will always fully scalarize the input
+ // type, gernerated multiple SOA declarations if the input
+ // was AOS. That choice is required for some cases in GLSL,
+ // and seems to be a reasonable default policy, but it could
+ // lead to some performance issues for shaders that rely
+ // on varying arrays.
+ //
+ // TODO: Consider whether some carefully designed early-out
+ // checks could avoid full scalarization when it is possible
+ // to avoid. Those early-out cases would probably need to
+ // align with the layout logic that is assigning `location`s
+ // to varying parameters.
+ //
+ auto type = info.type;
+ if (as<IRVoidType>(type))
+ {
+ return createSimpleLegalVaryingVal(info);
+ }
+ else if( as<IRBasicType>(type) )
+ {
+ return createSimpleLegalVaryingVal(info);
+ }
+ else if( as<IRVectorType>(type) )
+ {
+ return createSimpleLegalVaryingVal(info);
+ }
+ else if( as<IRMatrixType>(type) )
+ {
+ // Note: For now we are handling matrix types in a varying
+ // parameter list as if they were ordinary types like
+ // scalars and vectors. This works well enough for simple
+ // stuff, and is unlikely to see much use anyway.
+ //
+ // TODO: A more correct implementation will probably treat
+ // a matrix-type varying parameter as if it was syntax
+ // sugar for an array of rows.
+ //
+ return createSimpleLegalVaryingVal(info);
+ }
+ else if( auto arrayType = as<IRArrayType>(type) )
+ {
+ // A varying parameter of array type is an interesting beast,
+ // because depending on the element type of the array we
+ // might end up needing to generate multiple parameters in
+ // struct-of-arrays (SOA) fashion. This will notably
+ // come up in the case where the element type is a `struct`,
+ // with fields that mix both user-defined and system-value
+ // semantics.
+ //
+ auto elementType = arrayType->getElementType();
+ auto elementCount = arrayType->getElementCount();
+ auto arrayLayout = as<IRArrayTypeLayout>(info.typeLayout);
+ SLANG_ASSERT(arrayLayout);
+ auto elementTypeLayout = arrayLayout->getElementTypeLayout();
+
+ // We are going to recursively apply legalization to the
+ // element type of the array, but when doing so we will
+ // pass down information about the outer "array brackets"
+ // that this type represented.
+ //
+ VaryingArrayDeclaratorInfo arrayDeclarator;
+ arrayDeclarator.elementCount = elementCount;
+ arrayDeclarator.next = info.arrayDeclarators;
+
+ VaryingParamInfo elementInfo = info;
+ elementInfo.type = elementType;
+ elementInfo.typeLayout = elementTypeLayout;
+ elementInfo.arrayDeclarators = &arrayDeclarator;
+
+ return _createLegalVaryingVal(elementInfo);
+ }
+ else if( auto streamType = as<IRHLSLStreamOutputType>(type))
+ {
+ // Handling a geometry shader stream output type like
+ // `TriangleStream<T>` is similar to handling an array,
+ // but we do *not* pass down a "declarator" to note
+ // the wrapping type.
+ //
+ // This choice is appropriate for GLSL because geometry
+ // shader outputs are just declared as their per-vertex
+ // types and not wrapped in array or stream types.
+ //
+ // TODO: If we ever need to legalize geometry shaders for
+ // a target with different rules we might need to revisit
+ // this choice.
+ //
+ auto elementType = streamType->getElementType();
+ auto streamLayout = as<IRStreamOutputTypeLayout>(info.typeLayout);
+ SLANG_ASSERT(streamLayout);
+ auto elementTypeLayout = streamLayout->getElementTypeLayout();
+
+ VaryingParamInfo elementInfo = info;
+ elementInfo.type = elementType;
+ elementInfo.typeLayout = elementTypeLayout;
+
+ return _createLegalVaryingVal(elementInfo);
+ }
+ // Note: This file is currently missing the case for handling a varying `struct`.
+ // The relevant logic is present in `slang-ir-glsl-legalize`, but it would add
+ // a lot of complexity to this file to include it now.
+ //
+ // The main consequence of this choice is that this pass doesn't support varying
+ // parameters wrapped in `struct`s for the targets that require this pass
+ // (currently CPU and CUDA).
+ //
+ // TODO: Copy over the relevant logic from the GLSL-specific pass, as part of
+ // readying this file to handle the needs of all targets.
+ //
+ else
+ {
+ // When no special case matches, we assume the parameter
+ // has a simple type that we can handle directly.
+ //
+ return createSimpleLegalVaryingVal(info);
+ }
+ }
+
+ LegalizedVaryingVal createSimpleLegalVaryingVal(VaryingParamInfo const& info)
+ {
+ // At this point we've bottomed out in the type-based recursion
+ // and we have a leaf parameter of some simple type that should
+ // also have a single semantic name/index to work with.
+
+ // TODO: This seems like the right place to "wrap" the type back
+ // up in layers of array-ness based on the outer array brackets
+ // that were accumulated.
+
+ // Our first order of business will be to check whether the
+ // parameter represents a system-value parameter.
+ //
+ auto varLayout = info.varLayout;
+ auto semanticInst = varLayout->findSystemValueSemanticAttr();
+ if( semanticInst )
+ {
+ // We will compare the semantic name against our list of
+ // system-value semantics using conversion to lower-case
+ // to achieve a case-insensitive comparison (this is
+ // necessary because semantics in HLSL/Slang do not
+ // treat case as significant).
+ //
+ // TODO: It would be nice to have a case-insensitive
+ // comparsion operation on `UnownedStringSlice` to
+ // avoid all the `String`s we crete and thren throw
+ // away here.
+ //
+ String semanticNameSpelling = semanticInst->getName();
+ auto semanticName = semanticNameSpelling.toLower();
+
+ SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None;
+
+ #define CASE(ID, NAME) \
+ if(semanticName == String(#NAME).toLower()) \
+ { \
+ systemValueSemanticName = SystemValueSemanticName::ID; \
+ } \
+ else
+
+ SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
+ #undef CASE
+ {
+ // no match
+ }
+
+ if( systemValueSemanticName != SystemValueSemanticName::None )
+ {
+ // If the leaf parameter has a system-value semantic, then
+ // we need to translate the system value in whatever way
+ // is appropraite for the target.
+ //
+ // TODO: The logic here is missing the behavior from the
+ // GLSL-specific pass that handles type conversion when
+ // a user-declared system-value parameter might not
+ // match the type that was expected exactly (e.g., they
+ // declare a `uint2` but the parameter is a `uint3`).
+ //
+ VaryingParamInfo systemValueParamInfo = info;
+ systemValueParamInfo.systemValueSemanticName = systemValueSemanticName;
+ return createLegalSystemVaryingValImpl(systemValueParamInfo);
+ }
+
+ // TODO: We should seemingly do something if the semantic name
+ // implies a system-value semantic (starts with `SV_`) but we
+ // didn't find a match.
+ //
+ // In practice, this is probably something that should be handled
+ // at the layout level (`slang-parameter-binding.cpp`), and the
+ // layout for a parameter should include the `SystemValueSemanticName`
+ // as an enumerated value rather than a string (so that downstream
+ // code doesn't have to get into the business of parsing it).
+ }
+
+ // If there was semantic applied to the parameter *or* the semantic
+ // wasn't recognized as a system-value semantic, then we need
+ // to do whatever target-specific logic is required to legalize
+ // a user-defined varying parameter.
+ //
+ return createLegalUserVaryingValImpl(info);
+ }
+
+ // The base type will provide default implementations of the logic
+ // for creating user-defined and system-value varyings, but in
+ // each case the default logic will simply diagnose an error.
+ //
+ // For targets that support either case, it is essential to
+ // override these methods with appropriate logic.
+
+ virtual LegalizedVaryingVal createLegalUserVaryingValImpl(VaryingParamInfo const& info)
+ {
+ SLANG_UNUSED(info);
+
+ m_sink->diagnose(m_param, Diagnostics::unimplemented, "this target doesn't support user-defined varying parameters");
+
+ return LegalizedVaryingVal();
+ }
+
+ virtual LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info)
+ {
+ return diagnoseUnsupportedSystemVal(info);
+ }
+
+ // As a utility for target-specific subtypes, we define a routine
+ // to diagnose the case of a system-value semantic that isn't
+ // understood by the target.
+
+ LegalizedVaryingVal diagnoseUnsupportedSystemVal(VaryingParamInfo const& info)
+ {
+ SLANG_UNUSED(info);
+
+ m_sink->diagnose(m_param, Diagnostics::unimplemented, "this target doesn't support this system-defined varying parameters");
+
+ return LegalizedVaryingVal();
+ }
+
+ // There are some cases of system-value inputs that can be derived
+ // from other inputs; notably compute shaders support `SV_DispatchThreadID`
+ // and `SV_GroupIndex` which can both be derived from the more primitive
+ // `SV_GroupID` and `SV_GroupThreadID`, together with the extents
+ // of the thread group (which are specified with `[numthreads(...)]`).
+ //
+ // As a utilty to target-specific subtypes, we define helpers for
+ // calculating the value of these derived system values from the
+ // more primitive ones.
+
+ /// Emit code to calculate `SV_DispatchThreadID`
+ IRInst* emitCalcDispatchThreadID(
+ IRBuilder& builder,
+ IRType* type,
+ IRInst* groupID,
+ IRInst* groupThreadID,
+ IRInst* groupExtents)
+ {
+ // The dispatch thread ID can be computed as:
+ //
+ // dispatchThreadID = groupID*groupExtents + groupThreadID
+ //
+ // where `groupExtents` is the X,Y,Z extents of
+ // each thread group in threads (as given by
+ // `[numthreads(X,Y,Z)]`).
+
+ return builder.emitAdd(type,
+ builder.emitMul(type,
+ groupID,
+ groupExtents),
+ groupThreadID);
+ }
+
+ /// Emit code to calculate `SV_GroupIndex`
+ IRInst* emitCalcGroupThreadIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents)
+ {
+ auto intType = builder.getIntType();
+ auto uintType = builder.getBasicType(BaseType::UInt);
+
+ // The group thread index can be computed as:
+ //
+ // groupThreadIndex = groupThreadID.x
+ // + groupThreadID.y*groupExtents.x
+ // + groupThreadID.z*groupExtents.x*groupExtents.z;
+ //
+ // or equivalently (with one less multiply):
+ //
+ // groupThreadIndex = (groupThreadID.z * groupExtents.y
+ // + groupThreadID.y) * groupExtents.x
+ // + groupThreadID.x;
+ //
+
+ // `offset = groupThreadID.z`
+ auto zAxis = builder.getIntValue(intType, 2);
+ IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
+
+ // `offset *= groupExtents.y`
+ // `offset += groupExtents.y`
+ auto yAxis = builder.getIntValue(intType, 1);
+ offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
+ offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
+
+ // `offset *= groupExtents.x`
+ // `offset += groupExtents.x`
+ auto xAxis = builder.getIntValue(intType, 0);
+ offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
+ offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
+
+ return offset;
+ }
+
+ // Several of the derived calcluations rely on having
+ // access to the "group extents" of a compute shader.
+ // That information is expected to be present on
+ // the entry point as a `[numthreads(...)]` attribute,
+ // and we define a convenience routine for accessing
+ // that information.
+
+ IRInst* emitCalcGroupExtents(
+ IRBuilder& builder,
+ IRVectorType* type)
+ {
+ if(auto numThreadsDecor = m_entryPointFunc->findDecoration<IRNumThreadsDecoration>())
+ {
+ static const int kAxisCount = 3;
+ IRInst* groupExtentAlongAxis[kAxisCount] = {};
+
+ for( int axis = 0; axis < kAxisCount; axis++ )
+ {
+ auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
+ if(!litValue)
+ return nullptr;
+
+ groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
+ }
+
+ return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
+ }
+
+ // TODO: We may want to implement a backup option here,
+ // in case we ever want to support compute shaders with
+ // dynamic/flexible group size on targets that allow it.
+ //
+ SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
+ UNREACHABLE_RETURN(nullptr);
+ }
+};
+
+// With the target-independent core of the pass out of the way, we can
+// turn our attention to the target-specific subtypes that handle
+// translation of "leaf" varying parameters.
+
+struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalizeContext
+{
+ // CUDA compute kernels don't support user-defined varying
+ // input or output, and there are only a few system-value
+ // varying inputs to deal with.
+ //
+ // CUDA provides built-in global parameters `threadIdx`,
+ // `blockIdx`, and `blockDim` that we can make use of.
+ //
+ IRGlobalParam* threadIdxGlobalParam = nullptr;
+ IRGlobalParam* blockIdxGlobalParam = nullptr;
+ IRGlobalParam* blockDimGlobalParam = nullptr;
+
+ // All of our system values will be exposed with the
+ // `uint3` type, and we'll cache a pointer to that
+ // type to void looking it up repeatedly.
+ //
+ IRType* uint3Type = nullptr;
+
+ void beginModuleImpl() SLANG_OVERRIDE
+ {
+ // Because many of the varying parameters are defined
+ // as magic globals in CUDA, we can introduce their
+ // definitions once per module, instead of once per
+ // entry point.
+ //
+ IRBuilder builder(m_sharedBuilder);
+ builder.setInsertInto(m_module->getModuleInst());
+
+ // We begin by looking up the `uint` and `uint3` types.
+ //
+ auto uintType = builder.getBasicType(BaseType::UInt);
+ uint3Type = builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 3));
+
+ // Next we create IR type and variable layouts that
+ // we can use to mark the global parameters like
+ // `threadIdx` as varying parameters instead of
+ // uniform.
+ //
+ IRTypeLayout::Builder typeLayoutBuilder(&builder);
+ typeLayoutBuilder.addResourceUsage(LayoutResourceKind::VaryingInput, 1);
+ auto typeLayout = typeLayoutBuilder.build();
+
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
+ auto varLayout = varLayoutBuilder.build();
+
+ // Finaly, we construct global parameters to represent
+ // `threadIdx`, `blockIdx`, and `blockDim`.
+ //
+ // Each of these parameters is given a target-intrinsic
+ // decoration that ensures that (1) it will not get a declaration
+ // emitted in output code, and (2) it will be referenced
+ // by exactly the desired name (with no attempt to generate
+ // a unique name).
+
+ threadIdxGlobalParam = builder.createGlobalParam(uint3Type);
+ builder.addTargetIntrinsicDecoration(threadIdxGlobalParam, UnownedTerminatedStringSlice(""), UnownedTerminatedStringSlice("threadIdx"));
+ builder.addLayoutDecoration(threadIdxGlobalParam, varLayout);
+
+ blockIdxGlobalParam = builder.createGlobalParam(uint3Type);
+ builder.addTargetIntrinsicDecoration(blockIdxGlobalParam, UnownedTerminatedStringSlice(""), UnownedTerminatedStringSlice("blockIdx"));
+ builder.addLayoutDecoration(blockIdxGlobalParam, varLayout);
+
+ blockDimGlobalParam = builder.createGlobalParam(uint3Type);
+ builder.addTargetIntrinsicDecoration(blockDimGlobalParam, UnownedTerminatedStringSlice(""), UnownedTerminatedStringSlice("blockDim"));
+ builder.addLayoutDecoration(blockDimGlobalParam, varLayout);
+ }
+
+ // While CUDA provides many useful system values
+ // as built-in globals, it does not provide the
+ // equivalent of `SV_DispatchThreadID` or
+ // `SV_GroupIndex` as a built-in.
+ //
+ // We will instead synthesize those values on
+ // entry to each kernel.
+
+ IRInst* groupThreadIndex = nullptr;
+ IRInst* dispatchThreadID = nullptr;
+
+ void beginEntryPointImpl() SLANG_OVERRIDE
+ {
+ IRBuilder builder(m_sharedBuilder);
+ builder.setInsertBefore(m_firstOrdinaryInst);
+
+ // Note that we can use the built-in `blockDim`
+ // variable to determine the group extents,
+ // instead of inspecting the `[numthreads(...)]`
+ // attribute.
+ //
+ // This choice makes our output more idomatic
+ // as CUDA code, but might also cost a small
+ // amount of performance by not folding in
+ // the known constant values from `numthreads`.
+ //
+ // TODO: Add logic to use the values from
+ // `numthreads` if it is present, but to fall
+ // back to `blockDim` if not?
+
+ dispatchThreadID = emitCalcDispatchThreadID(
+ builder,
+ uint3Type,
+ blockIdxGlobalParam,
+ threadIdxGlobalParam,
+ blockDimGlobalParam);
+
+ groupThreadIndex = emitCalcGroupThreadIndex(
+ builder,
+ threadIdxGlobalParam,
+ blockDimGlobalParam);
+
+ // Note: we don't pay attention to whether the
+ // kernel actually makes use of either of these
+ // system values when we synthesize them.
+ //
+ // We can get away with this because we know
+ // that subsequent DCE passes will eliminate
+ // the computations if they aren't used.
+ //
+ // The main alternative would be to compute
+ // these values lazily, when they are first
+ // referenced. While that is possible, it
+ // requires more (and more subtle) code in this pass.
+ }
+
+ LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE
+ {
+ // Because all of the relevant values are either
+ // ambiently available in CUDA, or were computed
+ // eagerly in the entry block to the kernel
+ // function, we can easily return the right
+ // value to use for a system-value parameter.
+
+ switch( info.systemValueSemanticName )
+ {
+ case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(blockIdxGlobalParam);
+ case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(threadIdxGlobalParam);
+ case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID);
+
+ default:
+ return diagnoseUnsupportedSystemVal(info);
+ }
+ }
+};
+
+
+struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalizeContext
+{
+ // Slang translates compute shaders for CPU such that they always have an
+ // initial parameter that is a `ComputeThreadVaryingInput*`, and that
+ // type provides the essential parameters (`SV_GroupID` and `SV_GroupThreadID`
+ // as fields).
+ //
+ // Our legalization pass for CPU this begins with the per-module logic
+ // to synthesize an IR definition of that type and its fields, so that
+ // we can use it across entry points.
+
+ IRType* uintType = nullptr;
+ IRVectorType* uint3Type = nullptr;
+ IRType* uint3PtrType = nullptr;
+
+ IRStructType* varyingInputStructType = nullptr;
+ IRPtrType* varyingInputStructPtrType = nullptr;
+
+ IRStructKey* groupIDKey = nullptr;
+ IRStructKey* groupThreadIDKey = nullptr;
+
+ void beginModuleImpl() SLANG_OVERRIDE
+ {
+ IRBuilder builder(m_sharedBuilder);
+ builder.setInsertInto(m_module->getModuleInst());
+
+ uintType = builder.getBasicType(BaseType::UInt);
+ uint3Type = builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 3));
+ uint3PtrType = builder.getPtrType(uint3Type);
+
+ // As we construct the `ComputeThreadVaryingInput` type and its fields,
+ // we mark them all as target intrinsics, which means that their
+ // declarations will *not* be reproduced in the output code, instead
+ // coming from the "prelude" file that already defines this type.
+
+ varyingInputStructType = builder.createStructType();
+ varyingInputStructPtrType = builder.getPtrType(varyingInputStructType);
+
+ builder.addTargetIntrinsicDecoration(varyingInputStructType, UnownedTerminatedStringSlice(""), UnownedTerminatedStringSlice("ComputeThreadVaryingInput"));
+
+ groupIDKey = builder.createStructKey();
+ builder.addTargetIntrinsicDecoration(groupIDKey, UnownedTerminatedStringSlice(""), UnownedTerminatedStringSlice("groupID"));
+ builder.createStructField(varyingInputStructType, groupIDKey, uint3Type);
+
+ groupThreadIDKey = builder.createStructKey();
+ builder.addTargetIntrinsicDecoration(groupThreadIDKey, UnownedTerminatedStringSlice(""), UnownedTerminatedStringSlice("groupThreadID"));
+ builder.createStructField(varyingInputStructType, groupThreadIDKey, uint3Type);
+ }
+
+ // While the declaration of the `ComputeVaryingThreadInput` type
+ // can be shared across all entry points, each entry point must
+ // declare its own parameter to receive the varying parameters.
+ //
+ // We will extract the relevant fields from the `ComputeVaryingThreadInput`
+ // at the start of kernel execution (rather than repeatedly load them
+ // at each use site), and will also eagerly compute the derived
+ // values for `SV_DispatchThreadID` and `SV_GroupIndex`.
+
+ IRInst* groupID = nullptr;
+ IRInst* groupThreadID = nullptr;
+ IRInst* groupExtents = nullptr;
+ IRInst* dispatchThreadID = nullptr;
+ IRInst* groupThreadIndex = nullptr;
+
+ void beginEntryPointImpl() SLANG_OVERRIDE
+ {
+ groupID = nullptr;
+ groupThreadID = nullptr;
+ dispatchThreadID = nullptr;
+
+ IRBuilder builder(m_sharedBuilder);
+
+ auto varyingInputParam = builder.createParam(varyingInputStructPtrType);
+ varyingInputParam->insertBefore(m_firstBlock->getFirstChild());
+
+ builder.setInsertBefore(m_firstOrdinaryInst);
+
+ groupID = builder.emitLoad(
+ builder.emitFieldAddress(uint3PtrType, varyingInputParam, groupIDKey));
+
+ groupThreadID = builder.emitLoad(
+ builder.emitFieldAddress(uint3PtrType, varyingInputParam, groupThreadIDKey));
+
+ // Note: we need to rely on the presence of the `[numthreads(...)]` attribute
+ // to tell us the size of the compute thread group, which we will then use
+ // when computing the dispatch thread ID and group thread index.
+ //
+ // TODO: If we ever wanted to support flexible thread-group sizes for our
+ // CPU target, we'd need to change it so that the thread-group size can
+ // be passed in as part of `ComputeVaryingThreadInput`.
+ //
+ groupExtents = emitCalcGroupExtents(builder, uint3Type);
+
+ dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
+
+ groupThreadIndex = emitCalcGroupThreadIndex(builder, groupThreadID, groupExtents);
+ }
+
+ LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE
+ {
+ // Because all of the relvant system values were synthesized
+ // into the first block of the entry-point function, we can
+ // just return them wherever they are referenced.
+ //
+ // Note that any values that were synthesized but then are
+ // not referened will simply be eliminated as dead code
+ // in later passes.
+
+ switch( info.systemValueSemanticName )
+ {
+ case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(groupID);
+ case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(groupThreadID);
+ case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID);
+
+ default:
+ return diagnoseUnsupportedSystemVal(info);
+ }
+ }
+};
+
+void legalizeEntryPointVaryingParamsForCPU(
+ IRModule* module,
+ DiagnosticSink* sink)
+{
+ CPUEntryPointVaryingParamLegalizeContext context;
+ context.processModule(module, sink);
+}
+
+void legalizeEntryPointVaryingParamsForCUDA(
+ IRModule* module,
+ DiagnosticSink* sink)
+{
+ CUDAEntryPointVaryingParamLegalizeContext context;
+ context.processModule(module, sink);
+}
+
+}
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
new file mode 100644
index 000000000..ff93f38dd
--- /dev/null
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -0,0 +1,20 @@
+// slang-ir-legalize-varying-params.h
+#pragma once
+
+namespace Slang
+{
+
+class DiagnosticSink;
+
+struct IRFunc;
+struct IRModule;
+
+void legalizeEntryPointVaryingParamsForCPU(
+ IRModule* module,
+ DiagnosticSink* sink);
+
+void legalizeEntryPointVaryingParamsForCUDA(
+ IRModule* module,
+ DiagnosticSink* sink);
+
+}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 026df865d..b2ddc8ed3 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3622,6 +3622,29 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitAdd(IRType* type, IRInst* left, IRInst* right)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_Add,
+ type,
+ left,
+ right);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitMul(IRType* type, IRInst* left, IRInst* right)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_Mul,
+ type,
+ left,
+ right);
+ addInst(inst);
+ return inst;
+ }
//
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
index 95f98e4ec..d9c15fd23 100644
--- a/source/slang/slang.vcxproj
+++ b/source/slang/slang.vcxproj
@@ -233,6 +233,7 @@
<ClInclude Include="slang-ir-inst-defs.h" />
<ClInclude Include="slang-ir-insts.h" />
<ClInclude Include="slang-ir-layout.h" />
+ <ClInclude Include="slang-ir-legalize-varying-params.h" />
<ClInclude Include="slang-ir-link.h" />
<ClInclude Include="slang-ir-lower-generics.h" />
<ClInclude Include="slang-ir-missing-return.h" />
@@ -323,6 +324,7 @@
<ClCompile Include="slang-ir-inline.cpp" />
<ClCompile Include="slang-ir-layout.cpp" />
<ClCompile Include="slang-ir-legalize-types.cpp" />
+ <ClCompile Include="slang-ir-legalize-varying-params.cpp" />
<ClCompile Include="slang-ir-link.cpp" />
<ClCompile Include="slang-ir-lower-generics.cpp" />
<ClCompile Include="slang-ir-missing-return.cpp" />
@@ -418,4 +420,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
-</Project> \ No newline at end of file
+</Project> \ No newline at end of file
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
index 561599a8f..19a571e0b 100644
--- a/source/slang/slang.vcxproj.filters
+++ b/source/slang/slang.vcxproj.filters
@@ -150,6 +150,9 @@
<ClInclude Include="slang-ir-layout.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="slang-ir-legalize-varying-params.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="slang-ir-link.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -416,6 +419,9 @@
<ClCompile Include="slang-ir-legalize-types.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="slang-ir-legalize-varying-params.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="slang-ir-link.cpp">
<Filter>Source Files</Filter>
</ClCompile>