summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-buffer-element-type.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp168
1 files changed, 119 insertions, 49 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 3ef94d415..0472e44df 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir-clone.h"
+#include "slang-ir-layout.h"
namespace Slang
{
@@ -17,13 +18,15 @@ namespace Slang
IRFunc* convertOriginalToLowered = nullptr;
IRFunc* convertLoweredToOriginal = nullptr;
};
- Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo;
- Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo;
+
+ Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count];
+ Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count];
SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ TargetRequest* target;
- LoweredElementTypeContext(SlangMatrixLayoutMode inDefaultMatrixLayout)
- : defaultMatrixLayout(inDefaultMatrixLayout)
+ LoweredElementTypeContext(TargetRequest* target, SlangMatrixLayoutMode inDefaultMatrixLayout)
+ : target(target), defaultMatrixLayout(inDefaultMatrixLayout)
{}
IRFunc* createMatrixUnpackFunc(
@@ -200,7 +203,18 @@ namespace Slang
return func;
}
- LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type)
+ const char* getLayoutName(IRTypeLayoutRuleName name)
+ {
+ switch (name)
+ {
+ case IRTypeLayoutRuleName::Std140: return "std140";
+ case IRTypeLayoutRuleName::Std430: return "std430";
+ case IRTypeLayoutRuleName::Natural: return "natural";
+ default: return "default";
+ }
+ }
+
+ LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules)
{
IRBuilder builder(type);
builder.setInsertAfter(type);
@@ -224,12 +238,20 @@ namespace Slang
nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount());
if (isColMajor)
nameSB << "_ColMajor";
+ nameSB << getLayoutName(rules->ruleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
auto structKey = builder.createStructKey();
builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
auto vectorType = builder.getVectorType(matrixType->getElementType(),
isColMajor?matrixType->getRowCount():matrixType->getColumnCount());
- auto arrayType = builder.getArrayType(vectorType, isColMajor?matrixType->getColumnCount():matrixType->getRowCount());
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(rules, vectorType, &elementSizeAlignment);
+ elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
+
+ auto arrayType = builder.getArrayType(
+ vectorType,
+ isColMajor?matrixType->getColumnCount():matrixType->getRowCount(),
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.size));
builder.createStructField(loweredType, structKey, arrayType);
info.loweredType = loweredType;
@@ -241,45 +263,48 @@ namespace Slang
}
else if (auto arrayType = as<IRArrayType>(type))
{
- auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType());
-
- if (loweredInnerTypeInfo.loweredType != loweredInnerTypeInfo.originalType)
- {
- auto loweredType = builder.createStructType();
- info.loweredType = loweredType;
- StringBuilder nameSB;
- nameSB << "_ArrayStorage_";
- getTypeNameHint(nameSB, arrayType->getElementType());
- nameSB << getIntVal(arrayType->getElementCount());
- builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
- auto structKey = builder.createStructKey();
- builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
- auto innerArrayType = builder.getArrayType(loweredInnerTypeInfo.loweredType, arrayType->getElementCount());
- builder.createStructField(loweredType, structKey, innerArrayType);
- info.loweredInnerArrayType = innerArrayType;
- info.loweredInnerStructKey = structKey;
- info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo);
- info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo);
- }
- else
+ auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules);
+ if (!loweredInnerTypeInfo.convertLoweredToOriginal && rules->ruleName == IRTypeLayoutRuleName::Natural)
{
info.loweredType = type;
+ return info;
}
+ auto loweredType = builder.createStructType();
+ info.loweredType = loweredType;
+ StringBuilder nameSB;
+ nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_";
+ getTypeNameHint(nameSB, arrayType->getElementType());
+ nameSB << getIntVal(arrayType->getElementCount());
+ builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+ auto structKey = builder.createStructKey();
+ builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(rules, loweredType, &elementSizeAlignment);
+ elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
+ auto innerArrayType = builder.getArrayType(
+ loweredInnerTypeInfo.loweredType,
+ arrayType->getElementCount(),
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.size));
+ builder.createStructField(loweredType, structKey, innerArrayType);
+ info.loweredInnerArrayType = innerArrayType;
+ info.loweredInnerStructKey = structKey;
+ info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo);
+ info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo);
+
return info;
}
else if (auto structType = as<IRStructType>(type))
{
- bool hasNonTrivialField = false;
List<LoweredElementTypeInfo> fieldLoweredTypeInfo;
+ bool isTrivial = true;
for (auto field : structType->getFields())
{
- auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType());
+ auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules);
fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
- if (loweredFieldTypeInfo.loweredType != loweredFieldTypeInfo.originalType)
- hasNonTrivialField = true;
+ if (loweredFieldTypeInfo.convertLoweredToOriginal || rules->ruleName != IRTypeLayoutRuleName::Natural)
+ isTrivial = false;
}
-
- if (!hasNonTrivialField)
+ if (isTrivial)
{
info.loweredType = type;
return info;
@@ -288,10 +313,9 @@ namespace Slang
auto loweredType = builder.createStructType();
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
- nameSB << "_Storage";
+ nameSB << "_" << getLayoutName(rules->ruleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
info.loweredType = loweredType;
-
// Create fields.
{
Index fieldId = 0;
@@ -340,7 +364,7 @@ namespace Slang
Index fieldId = 0;
for (auto field : structType->getFields())
{
- auto fieldVal = builder.emitFieldExtract(type, param, field->getKey());
+ auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey());
auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered
? builder.emitCallInst(fieldLoweredTypeInfo[fieldId].loweredType, fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, 1, &fieldVal)
: fieldVal;
@@ -358,14 +382,16 @@ namespace Slang
return info;
}
- LoweredElementTypeInfo getLoweredTypeInfo(IRType* type)
+ LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules)
{
LoweredElementTypeInfo info;
- if (loweredTypeInfo.tryGetValue(type, info))
+ if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info))
return info;
- info = getLoweredTypeInfoImpl(type);
- loweredTypeInfo[type] = info;
- mapLoweredTypeToInfo[info.loweredType] = info;
+ info = getLoweredTypeInfoImpl(type, rules);
+ IRSizeAndAlignment sizeAlignment;
+ getSizeAndAlignment(rules, info.loweredType, &sizeAlignment);
+ loweredTypeInfo[(int)rules->ruleName].set(type, info);
+ mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info);
return info;
}
@@ -389,6 +415,12 @@ namespace Slang
return nullptr;
}
+ struct MatrixAddrWorkItem
+ {
+ IRInst* matrixAddrInst;
+ IRTypeLayoutRules* layoutRules;
+ };
+
void processModule(IRModule* module)
{
IRBuilder builder(module);
@@ -414,13 +446,15 @@ namespace Slang
// Maintain a pending work list of all matrix addresses, and try to lower them out of existance
// after everything else has been lowered.
- List<IRInst*> matrixAddrInsts;
+
+ List<MatrixAddrWorkItem> matrixAddrInsts;
for (auto bufferTypeInfo : bufferTypeInsts)
{
auto bufferType = bufferTypeInfo.bufferType;
auto elementType = bufferTypeInfo.elementType;
- auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType);
+ auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType);
+ auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules);
// If the lowered type is the same as original type, no change is required.
if (!loweredBufferElementTypeInfo.convertLoweredToOriginal)
@@ -450,7 +484,7 @@ namespace Slang
auto ptrVal = ptrValsWorkList[i];
auto oldPtrType = ptrVal->getFullType();
auto originalElementType = oldPtrType->getOperand(0);
- auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType);
+ auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType, layoutRules);
if (!loweredElementTypeInfo.convertLoweredToOriginal)
continue;
@@ -517,7 +551,7 @@ namespace Slang
// We are tring to get a pointer to a lowered matrix element.
// We process this insts at a later phase.
SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr);
- matrixAddrInsts.add(user);
+ matrixAddrInsts.add(MatrixAddrWorkItem{ user, layoutRules });
}
else
{
@@ -544,19 +578,24 @@ namespace Slang
bufferType->removeAndDeallocate();
}
+ // Process all matrix address uses.
lowerMatrixAddresses(module, matrixAddrInsts);
}
// Lower all getElementPtr insts of a lowered matrix out of existance.
- void lowerMatrixAddresses(IRModule* module, List<IRInst*>& matrixAddrInsts)
+ void lowerMatrixAddresses(IRModule* module, List<MatrixAddrWorkItem>& matrixAddrInsts)
{
IRBuilder builder(module);
- for (auto majorAddr : matrixAddrInsts)
+ for (auto workItem : matrixAddrInsts)
{
+ auto majorAddr = workItem.matrixAddrInst;
+ auto layoutRules = workItem.layoutRules;
+
+ int layoutRuleName = (int)layoutRules->ruleName;
auto majorGEP = as<IRGetElementPtr>(majorAddr);
SLANG_ASSERT(majorGEP);
auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType();
- auto matrixTypeInfo = mapLoweredTypeToInfo.tryGetValue(loweredMatrixType);
+ auto matrixTypeInfo = mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType);
SLANG_ASSERT(matrixTypeInfo);
auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType);
auto rowCount = getIntVal(matrixType->getRowCount());
@@ -652,7 +691,38 @@ namespace Slang
SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getDefaultMatrixLayoutMode();
if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- LoweredElementTypeContext context(defaultMatrixMode);
+ LoweredElementTypeContext context(target, defaultMatrixMode);
context.processModule(module);
}
+
+
+ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetRequest* target, IRType* bufferType)
+ {
+ if (!isKhronosTarget(target))
+ return IRTypeLayoutRules::getNatural();
+
+ // If we are just emitting GLSL, we can just use the general layout rule.
+ if (!target->shouldEmitSPIRVDirectly())
+ return IRTypeLayoutRules::getNatural();
+
+ // If the user specified a scalar buffer layout, then just use that.
+ if (target->getForceGLSLScalarBufferLayout())
+ return IRTypeLayoutRules::getNatural();
+
+ // The default behavior is to use std140 for constant buffers and std430 for other buffers.
+ switch (bufferType->getOp())
+ {
+ case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLRWStructuredBufferType:
+ case kIROp_HLSLAppendStructuredBufferType:
+ case kIROp_HLSLConsumeStructuredBufferType:
+ case kIROp_HLSLRasterizerOrderedStructuredBufferType:
+ return IRTypeLayoutRules::getStd430();
+ case kIROp_ConstantBufferType:
+ case kIROp_ParameterBlockType:
+ return IRTypeLayoutRules::getStd140();
+ }
+ return IRTypeLayoutRules::getNatural();
+ }
+
}