summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-spirv.cpp13
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp150
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp17
-rw-r--r--source/slang/slang-parameter-binding.cpp53
-rw-r--r--source/slang/slang-type-layout.cpp10
-rw-r--r--source/slang/slang-type-system-shared.h4
6 files changed, 196 insertions, 51 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index c618946ec..92fa507e0 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1279,8 +1279,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case AddressSpace::Uniform:
return SpvStorageClassUniform;
case AddressSpace::Input:
+ case AddressSpace::BuiltinInput:
return SpvStorageClassInput;
case AddressSpace::Output:
+ case AddressSpace::BuiltinOutput:
return SpvStorageClassOutput;
case AddressSpace::TaskPayloadWorkgroup:
return SpvStorageClassTaskPayloadWorkgroupEXT;
@@ -2688,7 +2690,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
IRBuilder builder(spvAsmBuiltinVar);
builder.setInsertBefore(spvAsmBuiltinVar);
auto varInst = getBuiltinGlobalVar(
- builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), AddressSpace::Input),
+ builder.getPtrType(
+ kIROp_PtrType,
+ spvAsmBuiltinVar->getDataType(),
+ AddressSpace::BuiltinInput),
kind,
spvAsmBuiltinVar);
registerInst(spvAsmBuiltinVar, varInst);
@@ -4214,7 +4219,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto addrSpace = ptrType->getAddressSpace();
if (addrSpace != AddressSpace::Input &&
- addrSpace != AddressSpace::Output)
+ addrSpace != AddressSpace::Output &&
+ addrSpace != AddressSpace::BuiltinInput &&
+ addrSpace != AddressSpace::BuiltinOutput)
continue;
}
}
@@ -4995,7 +5002,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (!ptrType)
return;
auto addrSpace = ptrType->getAddressSpace();
- if (addrSpace == AddressSpace::Input)
+ if (addrSpace == AddressSpace::Input || addrSpace == AddressSpace::BuiltinInput)
{
if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
{
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index dd62ca02c..7f67c9254 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -8,6 +8,22 @@
namespace Slang
{
+
+struct TypeLoweringConfig
+{
+ AddressSpace addressSpace;
+ IRTypeLayoutRules* layoutRule;
+ bool operator==(const TypeLoweringConfig& other) const
+ {
+ return addressSpace == other.addressSpace && layoutRule == other.layoutRule;
+ }
+ HashCode getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule));
+ }
+};
+TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType);
+
struct LoweredElementTypeContext
{
static const IRIntegerValue kMaxArraySizeToUnroll = 32;
@@ -67,9 +83,13 @@ struct LoweredElementTypeContext
ConversionMethod convertLoweredToOriginal;
};
- Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count];
- Dictionary<IRType*, LoweredElementTypeInfo>
- mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count];
+ struct LoweredTypeMap : RefObject
+ {
+ Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo;
+ Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo;
+ };
+
+ Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps;
struct ConversionMethodKey
{
@@ -392,7 +412,7 @@ struct LoweredElementTypeContext
return 4;
}
- LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules)
+ LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config)
{
IRBuilder builder(type);
builder.setInsertAfter(type);
@@ -409,7 +429,7 @@ struct LoweredElementTypeContext
// For other targets, we only lower the matrix types if they differ from the default
// matrix layout.
if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
- rules->ruleName == IRTypeLayoutRuleName::Natural)
+ config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural)
{
info.loweredType = type;
return info;
@@ -426,12 +446,12 @@ struct LoweredElementTypeContext
<< getIntVal(matrixType->getColumnCount());
if (isColMajor)
nameSB << "_ColMajor";
- nameSB << getLayoutName(rules->ruleName);
+ nameSB << getLayoutName(config.layoutRule->ruleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
auto structKey = builder.createStructKey();
builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
- if (rules->ruleName == IRTypeLayoutRuleName::Std140 &&
+ if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
options.use16ByteArrayElementForConstantBuffer)
{
// For constant buffer layout, we need to use 16-byte aligned vector if
@@ -443,8 +463,12 @@ struct LoweredElementTypeContext
auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
IRSizeAndAlignment elementSizeAlignment;
- getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment);
- elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ vectorType,
+ &elementSizeAlignment);
+ elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
auto arrayType = builder.getArrayType(
vectorType,
@@ -463,9 +487,9 @@ struct LoweredElementTypeContext
}
else if (auto arrayType = as<IRArrayType>(type))
{
- auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules);
+ auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), config);
- if (rules->ruleName == IRTypeLayoutRuleName::Std140 &&
+ if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
options.use16ByteArrayElementForConstantBuffer)
{
// For constant buffer layout, we need to use 16-byte-aligned vector if
@@ -510,10 +534,10 @@ struct LoweredElementTypeContext
}
}
- // For spirv backend, we always want to lower all array types, even if the element type
- // comes out the same. This is because different layout rules may have different array
- // stride requirements.
- if (!target->shouldEmitSPIRVDirectly())
+ // For spirv backend, we always want to lower all array types for non-varying
+ // parameters, even if the element type comes out the same. This is because different
+ // layout rules may have different array stride requirements.
+ if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input)
{
if (!loweredInnerTypeInfo.convertLoweredToOriginal)
{
@@ -525,7 +549,7 @@ struct LoweredElementTypeContext
auto loweredType = builder.createStructType();
info.loweredType = loweredType;
StringBuilder nameSB;
- nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_";
+ nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
nameSB << getIntVal(arrayType->getElementCount());
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
@@ -534,10 +558,10 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- rules,
+ config.layoutRule,
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
- elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
+ elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayType(
loweredInnerTypeInfo.loweredType,
arrayType->getElementCount(),
@@ -566,10 +590,10 @@ struct LoweredElementTypeContext
bool isTrivial = true;
for (auto field : structType->getFields())
{
- auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules);
+ auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config);
fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
if (loweredFieldTypeInfo.convertLoweredToOriginal ||
- rules->ruleName != IRTypeLayoutRuleName::Natural)
+ config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural)
isTrivial = false;
}
@@ -589,7 +613,7 @@ struct LoweredElementTypeContext
auto loweredType = builder.createStructType();
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
- nameSB << "_" << getLayoutName(rules->ruleName);
+ nameSB << "_" << getLayoutName(config.layoutRule->ruleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
info.loweredType = loweredType;
// Create fields.
@@ -748,24 +772,40 @@ struct LoweredElementTypeContext
return info;
}
- LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules)
+ LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config)
+ {
+ RefPtr<LoweredTypeMap> map;
+ if (loweredTypeInfoMaps.tryGetValue(config, map))
+ return *map;
+ map = new LoweredTypeMap();
+ loweredTypeInfoMaps.add(config, map);
+ return *map;
+ }
+
+ LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, TypeLoweringConfig config)
{
// If `type` is already a lowered type, no more lowering is required.
LoweredElementTypeInfo info;
- if (mapLoweredTypeToInfo->tryGetValue(type))
+ auto& map = getTypeLoweringMap(config);
+ auto& mapLoweredTypeToInfo = map.mapLoweredTypeToInfo;
+ auto& loweredTypeInfo = map.loweredTypeInfo;
+ if (mapLoweredTypeToInfo.tryGetValue(type))
{
info.originalType = type;
info.loweredType = type;
return info;
}
-
- if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info))
+ if (loweredTypeInfo.tryGetValue(type, info))
return info;
- info = getLoweredTypeInfoImpl(type, rules);
+ info = getLoweredTypeInfoImpl(type, config);
IRSizeAndAlignment sizeAlignment;
- getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment);
- loweredTypeInfo[(int)rules->ruleName].set(type, info);
- mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info);
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ info.loweredType,
+ &sizeAlignment);
+ loweredTypeInfo.set(type, info);
+ mapLoweredTypeToInfo.set(info.loweredType, info);
conversionMethodMap[{info.originalType, info.loweredType}] = info.convertLoweredToOriginal;
conversionMethodMap[{info.loweredType, info.originalType}] = info.convertOriginalToLowered;
return info;
@@ -802,7 +842,7 @@ struct LoweredElementTypeContext
struct MatrixAddrWorkItem
{
IRInst* matrixAddrInst;
- IRTypeLayoutRules* layoutRules;
+ TypeLoweringConfig config;
};
void processModule(IRModule* module)
@@ -812,17 +852,25 @@ struct LoweredElementTypeContext
{
IRType* bufferType;
IRType* elementType;
+ bool shouldWrapArrayInStruct = false;
};
List<BufferTypeInfo> bufferTypeInsts;
for (auto globalInst : module->getGlobalInsts())
{
IRType* elementType = nullptr;
+
if (options.lowerBufferPointer)
{
- if (auto ptrType = as<IRPtrType>(globalInst))
+ if (auto ptrType = as<IRPtrTypeBase>(globalInst))
{
- if (ptrType->getAddressSpace() == AddressSpace::UserPointer)
+ switch (ptrType->getAddressSpace())
+ {
+ case AddressSpace::UserPointer:
+ case AddressSpace::Input:
+ case AddressSpace::Output:
elementType = ptrType->getValueType();
+ break;
+ }
}
}
else
@@ -849,8 +897,8 @@ struct LoweredElementTypeContext
{
auto bufferType = bufferTypeInfo.bufferType;
auto elementType = bufferTypeInfo.elementType;
- auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType);
- auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules);
+ auto config = getTypeLoweringConfigForBuffer(target, bufferType);
+ auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, config);
// If the lowered type is the same as original type, no change is required.
if (loweredBufferElementTypeInfo.loweredType ==
@@ -902,18 +950,18 @@ struct LoweredElementTypeContext
builder.setInsertBefore(ptrVal);
auto newArrayPtrVal = fieldAddr->getBase();
auto loweredInnerType =
- getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);
+ getLoweredTypeInfo(unsizedArrayType->getElementType(), config);
IRSizeAndAlignment arrayElementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- layoutRules,
+ config.layoutRule,
loweredInnerType.loweredType,
&arrayElementSizeAlignment);
IRSizeAndAlignment baseSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- layoutRules,
+ config.layoutRule,
tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
&baseSizeAlignment);
@@ -998,7 +1046,7 @@ struct LoweredElementTypeContext
if (!loweredElementTypeInfo.loweredType)
{
loweredElementTypeInfo =
- getLoweredTypeInfo((IRType*)originalElementType, layoutRules);
+ getLoweredTypeInfo((IRType*)originalElementType, config);
}
if (!loweredElementTypeInfo.convertLoweredToOriginal)
@@ -1087,7 +1135,7 @@ struct LoweredElementTypeContext
// We are tring to get a pointer to a lowered matrix element.
// We process this insts at a later phase.
SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr);
- matrixAddrInsts.add(MatrixAddrWorkItem{user, layoutRules});
+ matrixAddrInsts.add(MatrixAddrWorkItem{user, config});
}
else
{
@@ -1160,15 +1208,12 @@ struct LoweredElementTypeContext
for (auto workItem : matrixAddrInsts)
{
auto majorAddr = workItem.matrixAddrInst;
- auto layoutRules = workItem.layoutRules;
-
- int layoutRuleName = (int)layoutRules->ruleName;
auto majorGEP = as<IRGetElementPtr>(majorAddr);
SLANG_ASSERT(majorGEP);
auto loweredMatrixType =
cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType();
- auto matrixTypeInfo =
- mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType);
+ auto matrixTypeInfo = getTypeLoweringMap(workItem.config)
+ .mapLoweredTypeToInfo.tryGetValue(loweredMatrixType);
SLANG_ASSERT(matrixTypeInfo);
auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType);
auto rowCount = getIntVal(matrixType->getRowCount());
@@ -1379,4 +1424,21 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
return IRTypeLayoutRules::getNatural();
}
+TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType)
+{
+ AddressSpace addrSpace = AddressSpace::Generic;
+ if (auto ptrType = as<IRPtrTypeBase>(bufferType))
+ {
+ switch (ptrType->getAddressSpace())
+ {
+ case AddressSpace::Input:
+ case AddressSpace::Output:
+ addrSpace = AddressSpace::Input;
+ break;
+ }
+ }
+ auto rules = getTypeLayoutRuleForBuffer(target, bufferType);
+ return TypeLoweringConfig{addrSpace, rules};
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index c9764b203..4c01e5640 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -309,7 +309,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Skip load's for referenced `Input` variables since a ref implies
// passing as is, which needs to be a pointer (pass as is).
if (user->getDataType() && user->getDataType()->getOp() == kIROp_RefType &&
- addressSpace == AddressSpace::Input)
+ (addressSpace == AddressSpace::Input ||
+ addressSpace == AddressSpace::BuiltinInput))
{
builder.replaceOperand(use, addr);
continue;
@@ -431,7 +432,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
if (semanticName == "sv_pointsize")
- addressSpace = AddressSpace::Input;
+ addressSpace = AddressSpace::BuiltinInput;
}
}
@@ -661,6 +662,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
"resolve a storage class address space.");
}
}
+
+ switch (result)
+ {
+ case AddressSpace::Input:
+ if (varLayout->findSystemValueSemanticAttr())
+ result = AddressSpace::BuiltinInput;
+ break;
+ case AddressSpace::Output:
+ if (varLayout->findSystemValueSemanticAttr())
+ result = AddressSpace::BuiltinOutput;
+ break;
+ }
return result;
}
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 7fb719053..33fa24f11 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2262,12 +2262,63 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
return ptrTypeLayout;
}
+ else if (auto optionalType = as<OptionalType>(type))
+ {
+ Array<Type*, 2> types =
+ makeArray(optionalType->getValueType(), context->getASTBuilder()->getBoolType());
+ auto tupleType = context->getASTBuilder()->getTupleType(types.getView());
+ return processEntryPointVaryingParameter(context, tupleType, state, varLayout);
+ }
+ else if (auto tupleType = as<TupleType>(type))
+ {
+ RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
+ structLayout->type = type;
+ for (Index i = 0; i < tupleType->getMemberCount(); i++)
+ {
+ auto fieldType = tupleType->getMember(i);
+ RefPtr<VarLayout> fieldVarLayout = new VarLayout();
+
+ // We don't really have a "field" decl, so just use the tuple-typed decl
+ // itself as the varDecl of the elements.
+ auto fieldDecl = (VarDeclBase*)varLayout->varDecl.getDecl();
+ fieldVarLayout->varDecl = fieldDecl;
+
+ structLayout->fields.add(fieldVarLayout);
+
+ auto fieldTypeLayout = processEntryPointVaryingParameterDecl(
+ context,
+ fieldDecl,
+ fieldType,
+ state,
+ fieldVarLayout);
+
+ if (!fieldTypeLayout)
+ {
+ getSink(context)->diagnose(
+ varLayout->varDecl,
+ Diagnostics::notValidVaryingParameter,
+ fieldType);
+ continue;
+ }
+ fieldVarLayout->typeLayout = fieldTypeLayout;
+
+ // Assign offsets in var layout for each resource kind of the type.
+ for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos)
+ {
+ auto kind = fieldTypeResInfo.kind;
+ auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind);
+ auto fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind);
+ fieldResInfo->index = structTypeResInfo->count.getFiniteValue();
+ structTypeResInfo->count += fieldTypeResInfo.count;
+ }
+ }
+ return structLayout;
+ }
// Catch declaration-reference types late in the sequence, since
// otherwise they will include all of the above cases...
else if (auto declRefType = as<DeclRefType>(type))
{
auto declRef = declRefType->getDeclRef();
-
if (auto structDeclRef = declRef.as<StructDecl>())
{
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index b54f813dd..17f1ae677 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -4785,10 +4785,18 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type
type,
rules);
}
+ else if (auto optionalType = as<OptionalType>(type))
+ {
+ // OptionalType should be laid out the same way as Tuple<T, bool>.
+ Array<Type*, 2> types =
+ makeArray(optionalType->getValueType(), context.astBuilder->getBoolType());
+ auto tupleType = context.astBuilder->getTupleType(types.getView());
+ return _createTypeLayout(context, tupleType);
+ }
else if (auto tupleType = as<TupleType>(type))
{
// A `Tuple` type is laid out exactly the same way as a `struct` type,
- // except that we want have a declref to the field.
+ // except that we won't have a declref to the field.
StructTypeLayoutBuilder typeLayoutBuilder;
StructTypeLayoutBuilder pendingDataTypeLayoutBuilder;
diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h
index 73370a4e6..f80267d2b 100644
--- a/source/slang/slang-type-system-shared.h
+++ b/source/slang/slang-type-system-shared.h
@@ -74,8 +74,12 @@ enum class AddressSpace : uint64_t
MetalObjectData,
// Corresponds to SPIR-V's SpvStorageClassInput
Input,
+ // Same as `Input`, but used for builtin input variables.
+ BuiltinInput,
// Corresponds to SPIR-V's SpvStorageClassOutput
Output,
+ // Same as `Output`, but used for builtin output variables.
+ BuiltinOutput,
// Corresponds to SPIR-V's SpvStorageClassTaskPayloadWorkgroupEXT
TaskPayloadWorkgroup,
// Corresponds to SPIR-V's SpvStorageClassFunction