summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/core/slang-list.h12
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp48
-rw-r--r--source/slang/slang-ir-lower-out-parameters.cpp503
-rw-r--r--source/slang/slang-ir-lower-out-parameters.h12
-rw-r--r--source/slang/slang-ir-util.h19
-rw-r--r--source/slang/slang-ir.cpp3
7 files changed, 598 insertions, 1 deletions
diff --git a/source/core/slang-list.h b/source/core/slang-list.h
index 7c96e3844..597f1b9b6 100644
--- a/source/core/slang-list.h
+++ b/source/core/slang-list.h
@@ -192,6 +192,18 @@ public:
Index getCount() const { return m_count; }
Index getCapacity() const { return m_capacity; }
+ template<typename Predicate>
+ Index countIf(Predicate predicate) const
+ {
+ Index count = 0;
+ for (Index i = 0; i < getCount(); ++i)
+ {
+ if (predicate((*this)[i]))
+ count++;
+ }
+ return count;
+ }
+
const T* getBuffer() const { return m_buffer; }
T* getBuffer() { return m_buffer; }
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 4f52bff5d..049772cb9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4975,7 +4975,7 @@ public:
IRSemanticDecoration* addSemanticDecoration(
IRInst* value,
UnownedStringSlice const& text,
- int index = 0)
+ IRIntegerValue index = 0)
{
return as<IRSemanticDecoration>(addDecoration(
value,
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index e5b3de90f..949962a5c 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -3,6 +3,8 @@
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
+#include "slang-ir-lower-out-parameters.h"
+#include "slang-ir-lower-tuple-types.h"
#include "slang-ir-util.h"
#include "slang-parameter-binding.h"
@@ -1880,6 +1882,7 @@ private:
auto structType = as<IRStructType>(param->getDataType());
builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
auto varLayout = findVarLayout(param);
+ SLANG_ASSERT(varLayout);
// If `param` already has a semantic, we don't want to hoist its fields out.
if (varLayout->findSystemValueSemanticAttr() != nullptr ||
@@ -2242,6 +2245,7 @@ private:
index++;
continue;
}
+ SLANG_ASSERT(typeLayout);
typeLayout->getFieldLayout(index);
auto fieldLayout = typeLayout->getFieldLayout(index);
if (auto offsetAttr = fieldLayout->findOffsetAttr(K))
@@ -4015,11 +4019,55 @@ private:
const UnownedStringSlice userSemanticName = toSlice("user_semantic");
};
+void legalizeVertexShaderOutputParamsForMetal(DiagnosticSink* sink, EntryPointInfo& entryPoint)
+{
+ const auto oldFunc = entryPoint.entryPointFunc;
+
+ // We can avoid this lowering if it's a simple scalar return as it's
+ // handled further down the pipeline
+ const bool hasOutParameters = anyOf(
+ oldFunc->getParams(),
+ [](auto param) { return as<IROutTypeBase>(param->getFullType()); });
+
+ auto returnType = oldFunc->getResultType();
+ if (!as<IRStructType>(returnType) && !hasOutParameters)
+ return;
+
+ const bool alwaysUseReturnStruct = true;
+ entryPoint.entryPointFunc = lowerOutParameters(oldFunc, sink, alwaysUseReturnStruct);
+
+ if (oldFunc == entryPoint.entryPointFunc)
+ return;
+
+ // Since this will no longer be the entry point function, remove those decorations
+ List<IRDecoration*> ds;
+ for (auto decor : oldFunc->getDecorations())
+ {
+ if (as<IRKeepAliveDecoration>(decor) || as<IREntryPointDecoration>(decor))
+ {
+ ds.add(decor);
+ }
+ }
+
+ for (auto decor : ds)
+ {
+ decor->removeFromParent();
+ }
+}
+
+
void legalizeEntryPointVaryingParamsForMetal(
IRModule* module,
DiagnosticSink* sink,
List<EntryPointInfo>& entryPoints)
{
+ for (auto& e : entryPoints)
+ {
+ if (e.entryPointDecor->getProfile().getStage() == Stage::Vertex)
+ {
+ legalizeVertexShaderOutputParamsForMetal(sink, e);
+ }
+ }
LegalizeMetalEntryPointContext context(module, sink);
context.legalizeEntryPoints(entryPoints);
}
diff --git a/source/slang/slang-ir-lower-out-parameters.cpp b/source/slang/slang-ir-lower-out-parameters.cpp
new file mode 100644
index 000000000..2eec66db5
--- /dev/null
+++ b/source/slang/slang-ir-lower-out-parameters.cpp
@@ -0,0 +1,503 @@
+#include "slang-ir-lower-out-parameters.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-inline.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+static IRSemanticDecoration* findSemanticDecoration(IRInst* inst)
+{
+ for (auto decor : inst->getDecorations())
+ {
+ if (auto semanticDecor = as<IRSemanticDecoration>(decor))
+ {
+ return semanticDecor;
+ }
+ }
+ return nullptr;
+}
+
+static void transferSemanticDecorations(IRInst* source, IRInst* target, IRBuilder& builder)
+{
+ for (auto decor : source->getDecorations())
+ {
+ if (auto semanticDecor = as<IRSemanticDecoration>(decor))
+ {
+ builder.addSemanticDecoration(
+ target,
+ semanticDecor->getSemanticName(),
+ semanticDecor->getSemanticIndex());
+ }
+ }
+}
+
+// Find the entry point layout information from a function
+static bool findEntryPointLayoutInfo(
+ IRFunc* func,
+ IREntryPointLayout*& outEntryPointLayout,
+ IRVarLayout*& outParamsLayout,
+ IRVarLayout*& outResultLayout)
+{
+ for (auto decor : func->getDecorations())
+ {
+ if (auto layoutDecor = as<IRLayoutDecoration>(decor))
+ {
+ if (auto entryLayout = as<IREntryPointLayout>(layoutDecor->getLayout()))
+ {
+ outEntryPointLayout = entryLayout;
+ if (entryLayout->getOperandCount() >= 2)
+ {
+ outParamsLayout = as<IRVarLayout>(entryLayout->getOperand(0));
+ outResultLayout = as<IRVarLayout>(entryLayout->getOperand(1));
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+static bool findReturnValueSemanticInfo(
+ IRFunc* func,
+ UnownedStringSlice& outSemanticName,
+ IRIntegerValue& outSemanticIndex)
+{
+ // Check for semantic on function itself
+ if (auto semanticDecor = findSemanticDecoration(func))
+ {
+ outSemanticName = semanticDecor->getSemanticName();
+ outSemanticIndex = semanticDecor->getSemanticIndex();
+ return true;
+ }
+
+ // Check for semantic in entry point layout
+ IREntryPointLayout* entryLayout = nullptr;
+ IRVarLayout* paramsLayout = nullptr;
+ IRVarLayout* resultLayout = nullptr;
+
+ if (findEntryPointLayoutInfo(func, entryLayout, paramsLayout, resultLayout))
+ {
+ for (UInt i = 0; i < resultLayout->getOperandCount(); i++)
+ {
+ auto operand = resultLayout->getOperand(i);
+ if (auto semanticAttr = as<IRSystemValueSemanticAttr>(operand))
+ {
+ outSemanticName = semanticAttr->getName();
+ outSemanticIndex = semanticAttr->getIndex();
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+// Structure to hold parameter information
+struct ParamInfo
+{
+ IRParam* origParam; // Original parameter
+ IRType* valueType; // Parameter value type (without out/inout wrapper)
+ bool isOut; // Is an out parameter
+ bool isInOut; // Is an inout parameter
+ IRParam* newParam; // New param once created, pure out params will always be null here
+ IRVar* outVar; // Out variable (nullptr for non-out params)
+ IRStructKey* outFieldKey; // Field key (for out params)
+};
+
+// Analyze parameters and collect information
+List<ParamInfo> collectParameterInfo(
+ IRFunc* func,
+ IRBuilder& builder,
+ List<IRStructKey*>& outKeys,
+ Dictionary<IRParam*, IRStructKey*>& paramToKeyMap)
+{
+ List<ParamInfo> paramInfos;
+
+ for (auto param = func->getFirstParam(); param; param = param->getNextParam())
+ {
+ ParamInfo info;
+ info.origParam = param;
+ info.newParam = nullptr;
+ info.outVar = nullptr;
+ info.outFieldKey = nullptr;
+
+ if (auto outType = as<IROutTypeBase>(param->getDataType()))
+ {
+ // Handle out/inout parameter
+ info.valueType = outType->getValueType();
+ info.isOut = true;
+ info.isInOut = (outType->getOp() == kIROp_InOutType);
+
+ // Create field key for out parameter
+ String fieldName = "param";
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ fieldName = String(nameHint->getName());
+
+ auto fieldKey = builder.createStructKey();
+ builder.addNameHintDecoration(fieldKey, UnownedStringSlice(fieldName.getBuffer()));
+
+ // Transfer semantic decorations
+ transferSemanticDecorations(param, fieldKey, builder);
+
+ // Store field key for layout
+ info.outFieldKey = fieldKey;
+ outKeys.add(fieldKey);
+ paramToKeyMap[param] = fieldKey;
+ }
+ else
+ {
+ // Regular parameter
+ info.valueType = param->getDataType();
+ info.isOut = false;
+ info.isInOut = false;
+ }
+
+ paramInfos.add(info);
+ }
+
+ return paramInfos;
+}
+
+// Create a result key for non-void return types
+static IRStructKey* createResultKey(IRFunc* func, IRBuilder& builder, List<IRStructKey*>& outKeys)
+{
+ if (as<IRVoidType>(func->getResultType()))
+ return nullptr;
+
+ IRStructKey* resultKey = builder.createStructKey();
+ builder.addNameHintDecoration(resultKey, UnownedStringSlice("result"));
+
+ // Transfer semantic decoration from function return value to struct key
+ UnownedStringSlice semanticName;
+ IRIntegerValue semanticIndex = 0;
+ if (findReturnValueSemanticInfo(func, semanticName, semanticIndex))
+ {
+ builder.addSemanticDecoration(resultKey, semanticName, semanticIndex);
+ }
+
+ outKeys.add(resultKey);
+ return resultKey;
+}
+
+// Determine if we need to transform the function
+static bool needsTransformation(const List<ParamInfo>& paramInfos, bool alwaysUseReturnStruct)
+{
+ if (alwaysUseReturnStruct)
+ return true;
+
+ for (auto& info : paramInfos)
+ {
+ if (info.isOut && !info.isInOut)
+ return true;
+ }
+
+ return false;
+}
+
+// Create the return type for the new function
+static IRType* createReturnType(
+ IRFunc* func,
+ IRBuilder& builder,
+ IRStructKey* resultKey,
+ const List<IRStructKey*>& outKeys,
+ const List<ParamInfo>& paramInfos,
+ bool alwaysUseReturnStruct,
+ IRStructType*& returnStruct)
+{
+ // Determine if we need a struct return type
+ bool needsStructReturn = alwaysUseReturnStruct ||
+ (resultKey != nullptr && outKeys.getCount()) || outKeys.getCount() > 1;
+
+ if (needsStructReturn)
+ {
+ returnStruct = builder.createStructType();
+
+ // Create name for struct
+ StringBuilder nameBuilder;
+ if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
+ nameBuilder << nameHint->getName() << "_Result";
+ else
+ nameBuilder << "Function_Result";
+
+ builder.addNameHintDecoration(
+ returnStruct,
+ UnownedStringSlice(nameBuilder.toString().getBuffer()));
+
+ // Create fields for the struct
+ if (resultKey)
+ {
+ builder.createStructField(returnStruct, resultKey, func->getResultType());
+ }
+
+ for (auto& info : paramInfos)
+ {
+ if (info.isOut && info.outFieldKey)
+ {
+ builder.createStructField(returnStruct, info.outFieldKey, info.valueType);
+ }
+ }
+
+ return returnStruct;
+ }
+ else if (outKeys.getCount())
+ {
+ // Find the first out parameter's type
+ for (auto& info : paramInfos)
+ {
+ if (info.isOut)
+ {
+ return info.valueType;
+ }
+ }
+ }
+
+ // Default case
+ return builder.getVoidType();
+}
+
+// Create parameters for the new function
+static void createNewParameters(
+ IRBuilder& builder,
+ List<ParamInfo>& paramInfos,
+ IRCloneEnv& cloneEnv,
+ Dictionary<IRParam*, IRParam*>& origToNewParamMap)
+{
+ for (auto& info : paramInfos)
+ {
+ if (!info.isOut || info.isInOut)
+ {
+ // Create parameter
+ auto newParam = builder.emitParam(info.valueType);
+
+ // Copy ALL decorations including layout
+ for (auto decor : info.origParam->getDecorations())
+ {
+ cloneDecoration(&cloneEnv, decor, newParam, builder.getModule());
+ }
+
+ info.newParam = newParam;
+ origToNewParamMap[info.origParam] = newParam;
+ }
+
+ if (info.isOut)
+ {
+ // Create out variable
+ auto var = builder.emitVar(info.valueType);
+ info.outVar = var;
+
+ // Initialize inout variables from parameters
+ if (info.isInOut && info.newParam)
+ {
+ builder.emitStore(var, info.newParam);
+ }
+ }
+ }
+}
+
+// Build the call to the original function
+static IRCall* buildOriginalFunctionCall(
+ IRFunc* func,
+ IRBuilder& builder,
+ const List<ParamInfo>& paramInfos)
+{
+ List<IRInst*> args;
+ for (auto& info : paramInfos)
+ {
+ if (info.isOut)
+ {
+ args.add(info.outVar);
+ }
+ else
+ {
+ args.add(info.newParam);
+ }
+ }
+
+ // Call the original function
+ return builder.emitCallInst(func->getResultType(), func, args);
+}
+
+// Construct the return value for the new function
+static IRInst* constructReturnValue(
+ IRBuilder& builder,
+ IRCall* callResult,
+ IRStructKey* resultKey,
+ IRStructType* returnStruct,
+ const List<IRStructKey*>& outKeys,
+ const List<ParamInfo>& paramInfos)
+{
+ if (returnStruct)
+ {
+ // Collect field values in order
+ List<IRInst*> fieldValues;
+
+ // Add original return value if non-void
+ if (resultKey)
+ {
+ fieldValues.add(callResult);
+ }
+
+ // Add out parameter values
+ for (auto& info : paramInfos)
+ {
+ if (info.isOut && info.outVar)
+ {
+ fieldValues.add(builder.emitLoad(info.outVar));
+ }
+ }
+
+ // Create struct with all field values
+ return builder.emitMakeStruct(returnStruct, fieldValues);
+ }
+ else if (outKeys.getCount())
+ {
+ // Single return value
+ if (resultKey)
+ {
+ return callResult;
+ }
+ else
+ {
+ // Get the out var from the first out parameter
+ for (auto& info : paramInfos)
+ {
+ if (info.isOut && info.outVar)
+ {
+ return builder.emitLoad(info.outVar);
+ }
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+// Transfer decorations from original to new function
+static void transferFunctionDecorations(
+ IRFunc* func,
+ IRFunc* newFunc,
+ IRBuilder& builder,
+ IRCloneEnv& cloneEnv)
+{
+ // Copy all decorations including layout - we'll keep the original layout
+ for (auto decor : func->getDecorations())
+ {
+ cloneDecoration(&cloneEnv, decor, newFunc, builder.getModule());
+ }
+}
+
+// Handle cleanup of original function if needed
+static void handleOriginalFunction(IRFunc* func, IRCall* callResult)
+{
+ // Count uses of original function
+ UInt useCount = 0;
+ for (auto use = func->firstUse; use; use = use->nextUse)
+ useCount++;
+
+ if (useCount == 1)
+ {
+ inlineCall(callResult);
+
+ // Remove decorations from old function
+ List<IRDecoration*> decorationsToRemove;
+ for (auto decor : func->getDecorations())
+ {
+ if (as<IRKeepAliveDecoration>(decor) || as<IREntryPointDecoration>(decor))
+ {
+ decorationsToRemove.add(decor);
+ }
+ }
+
+ for (auto decor : decorationsToRemove)
+ {
+ decor->removeFromParent();
+ }
+
+ func->removeAndDeallocate();
+ }
+}
+
+// Main function that orchestrates the transformation
+IRFunc* lowerOutParameters(
+ IRFunc* func,
+ [[maybe_unused]] DiagnosticSink* sink,
+ bool alwaysUseReturnStruct)
+{
+ IRBuilder builder(func->getModule());
+ IRCloneEnv cloneEnv;
+
+ // Data structures for tracking parameter information
+ List<IRStructKey*> outKeys;
+ Dictionary<IRParam*, IRStructKey*> paramToKeyMap;
+ Dictionary<IRParam*, IRParam*> origToNewParamMap;
+
+ // Create result key for non-void return types
+ IRStructKey* resultKey = createResultKey(func, builder, outKeys);
+
+ // Collect parameter information
+ List<ParamInfo> paramInfos = collectParameterInfo(func, builder, outKeys, paramToKeyMap);
+
+ // Check if transformation is needed
+ if (!needsTransformation(paramInfos, alwaysUseReturnStruct))
+ return func;
+
+ // Create new function
+ auto newFunc = builder.createFunc();
+
+ // Transfer decorations to new function
+ transferFunctionDecorations(func, newFunc, builder, cloneEnv);
+
+ // Create return type
+ IRStructType* returnStruct = nullptr;
+ IRType* resultType = createReturnType(
+ func,
+ builder,
+ resultKey,
+ outKeys,
+ paramInfos,
+ alwaysUseReturnStruct,
+ returnStruct);
+
+ // Collect parameter types for new function
+ List<IRType*> newParamTypes;
+ for (auto& info : paramInfos)
+ {
+ if (!info.isOut || info.isInOut)
+ {
+ newParamTypes.add(info.valueType);
+ }
+ }
+
+ // Set function type
+ auto funcType = builder.getFuncType(newParamTypes, resultType);
+ newFunc->setFullType(funcType);
+
+ // Create function body
+ auto firstBlock = builder.createBlock();
+ newFunc->addBlock(firstBlock);
+ builder.setInsertInto(firstBlock);
+
+ // Create parameters and variables
+ createNewParameters(builder, paramInfos, cloneEnv, origToNewParamMap);
+
+ // Build call to original function
+ IRCall* callResult = buildOriginalFunctionCall(func, builder, paramInfos);
+
+ // Construct return value
+ IRInst* returnValue =
+ constructReturnValue(builder, callResult, resultKey, returnStruct, outKeys, paramInfos);
+
+ builder.emitReturn(returnValue);
+
+ // Handle cleanup of original function
+ handleOriginalFunction(func, callResult);
+
+ return newFunc;
+}
+
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-lower-out-parameters.h b/source/slang/slang-ir-lower-out-parameters.h
new file mode 100644
index 000000000..ac7514892
--- /dev/null
+++ b/source/slang/slang-ir-lower-out-parameters.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct IRModule;
+class DiagnosticSink;
+
+IRFunc* lowerOutParameters(IRFunc* func, DiagnosticSink* sink, bool alwaysUseReturnStruct);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index aa1ae3989..cccec7e05 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -398,6 +398,25 @@ bool canOperationBeSpecConst(
IRInst* const* fixedArgs,
IRUse* operands);
bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs);
+
+// most of <algorithm> doesn't work on out non-const iterators, so define this
+// version
+template<typename Range, typename Predicate>
+constexpr bool anyOf(Range&& range, Predicate&& pred)
+{
+ // Handle both const and non-const ranges
+ auto first = range.begin();
+ auto last = range.end();
+
+ for (; first != last; ++first)
+ {
+ if (pred(*first))
+ {
+ return true;
+ }
+ }
+ return false;
+}
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 89d7a666e..bf332aaf7 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4156,7 +4156,10 @@ static TypeCastStyle _getTypeStyleId(IRType* type)
{
return _getTypeStyleId(matrixType->getElementType());
}
+ // Try to simplify style if we can, otherwise just handle it unsimplified
auto style = getTypeStyle(type->getOp());
+ if (style == kIROp_Invalid)
+ style = type->getOp();
switch (style)
{
case kIROp_IntType: