summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-30 20:28:34 -0700
committerGitHub <noreply@github.com>2024-07-30 20:28:34 -0700
commit6e4b82741893be55f6216c31e19650029c667078 (patch)
treefefd4529c6066763653732d7f93ca5cf07027a76 /source
parent04e7327a2067c82db3eaef51955f211e148ac933 (diff)
Fixes for Metal ParameterBlock support. (#4752)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-emit.cpp8
-rw-r--r--source/slang/slang-ir-check-shader-parameter-type.cpp65
-rw-r--r--source/slang/slang-ir-check-shader-parameter-type.h13
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir-legalize-types.cpp171
-rw-r--r--source/slang/slang-ir.cpp71
-rw-r--r--source/slang/slang-legalize-types.cpp12
8 files changed, 309 insertions, 42 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index fe1a25b39..cb3e39dd4 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -859,6 +859,9 @@ DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0
DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
DIAGNOSTIC(56002, Error, attemptToQuerySizeOfUnsizedArray, "cannot obtain the size of an unsized array.")
+// Metal
+DIAGNOSTIC(56100, Error, constantBufferInParameterBlockNotAllowedOnMetal, "nested 'ConstantBuffer' inside a 'ParameterBlock' is not supported on Metal, use 'ParameterBlock' instead.")
+
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index a8ed469fa..044f79531 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -15,6 +15,7 @@
#include "slang-ir-dce.h"
#include "slang-ir-diff-call.h"
#include "slang-ir-check-recursive-type.h"
+#include "slang-ir-check-shader-parameter-type.h"
#include "slang-ir-autodiff.h"
#include "slang-ir-defunctionalization.h"
#include "slang-ir-dll-export.h"
@@ -739,8 +740,15 @@ Result linkAndOptimizeIR(
}
if (targetProgram->getOptionSet().shouldRunNonEssentialValidation())
+ {
checkForRecursiveTypes(irModule, sink);
+ // For some targets, we are more restrictive about what types are allowed
+ // to be used as shader parameters in ConstantBuffer/ParameterBlock.
+ // We will check for these restrictions here.
+ checkForInvalidShaderParameterType(targetRequest, irModule, sink);
+ }
+
if (sink->getErrorCount() != 0)
return SLANG_FAIL;
diff --git a/source/slang/slang-ir-check-shader-parameter-type.cpp b/source/slang/slang-ir-check-shader-parameter-type.cpp
new file mode 100644
index 000000000..71833c838
--- /dev/null
+++ b/source/slang/slang-ir-check-shader-parameter-type.cpp
@@ -0,0 +1,65 @@
+#include "slang-ir-check-shader-parameter-type.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+ void checkForInvalidShaderParameterTypeForMetal(IRModule* module, DiagnosticSink* sink)
+ {
+ HashSet<IRInst*> workListSet;
+ List<IRInst*> workList;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (inst->getOp() == kIROp_ParameterBlockType)
+ {
+ auto type = inst->getOperand(0);
+ if (workListSet.add(type))
+ workList.add(type);
+ // Diagnose an error on `ParameterBlock<ConstantBuffer<T>>`.
+ if (type->getOp() == kIROp_ConstantBufferType)
+ {
+ bool foundUseSite = false;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (user->sourceLoc.isValid())
+ {
+ sink->diagnose(user, Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal);
+ foundUseSite = true;
+ break;
+ }
+ }
+ if (!foundUseSite)
+ sink->diagnose(inst, Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal);
+ }
+ }
+ }
+ // Diagnose an error any any struct fields whose type is `ConstantBuffer<T>` if the
+ // struct is used inside a `ParameterBlock`.
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto type = workList[i];
+ if (auto structType = as<IRStructType>(type))
+ {
+ for (auto field : structType->getFields())
+ {
+ auto fieldType = field->getFieldType();
+ if (fieldType->getOp() == kIROp_ConstantBufferType)
+ {
+ sink->diagnose(field->getKey(), Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal);
+ }
+ if (workListSet.add(fieldType))
+ workList.add(fieldType);
+ }
+ }
+ }
+ }
+
+ void checkForInvalidShaderParameterType(
+ TargetRequest* target,
+ IRModule* module,
+ DiagnosticSink* sink)
+ {
+ if (isMetalTarget(target))
+ checkForInvalidShaderParameterTypeForMetal(module, sink);
+ }
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-check-shader-parameter-type.h b/source/slang/slang-ir-check-shader-parameter-type.h
new file mode 100644
index 000000000..2ecc95fd7
--- /dev/null
+++ b/source/slang/slang-ir-check-shader-parameter-type.h
@@ -0,0 +1,13 @@
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ class DiagnosticSink;
+ class TargetRequest;
+
+ void checkForInvalidShaderParameterType(
+ TargetRequest* targetReq,
+ IRModule* module,
+ DiagnosticSink* sink);
+}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 3aa7d1f64..795a79c28 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4075,11 +4075,19 @@ public:
IRInst* emitIsType(IRInst* value, IRInst* witness, IRInst* typeOperand, IRInst* targetWitness);
IRInst* emitFieldExtract(
+ IRInst* base,
+ IRInst* fieldKey);
+
+ IRInst* emitFieldExtract(
IRType* type,
IRInst* base,
IRInst* field);
IRInst* emitFieldAddress(
+ IRInst* basePtr,
+ IRInst* fieldKey);
+
+ IRInst* emitFieldAddress(
IRType* type,
IRInst* basePtr,
IRInst* field);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 503b528b2..4d7759881 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -1270,7 +1270,6 @@ static LegalVal legalizeFieldAddress(
default:
return LegalVal::simple(
builder->emitFieldAddress(
- type.getSimple(),
legalPtrOperand.getSimple(),
fieldKey));
}
@@ -1971,67 +1970,203 @@ static LegalVal legalizeDefaultConstruct(
}
}
+// If a legalized `val` has a different flavor than `type`, try to coerce it to `type`.
+//
+static LegalVal coerceToLegalType(
+ IRTypeLegalizationContext* context,
+ LegalType type,
+ LegalVal val)
+{
+ switch (type.flavor)
+ {
+ case LegalType::Flavor::none:
+ return LegalVal();
+ case LegalType::Flavor::simple:
+ {
+ if (val.flavor != LegalVal::Flavor::simple)
+ return val;
+ auto simpleVal = val.getSimple();
+ if (simpleVal->getDataType() == type.getSimple())
+ return val;
+
+ auto resultType = type.getSimple();
+ auto structType = as<IRStructType>(resultType);
+ if (!structType)
+ {
+ auto resultValueType = tryGetPointedToType(context->builder, resultType);
+ if (!resultValueType)
+ return val;
+ auto valValueType = tryGetPointedToType(context->builder, simpleVal->getDataType());
+ if (!valValueType)
+ return val;
+ if (resultValueType == valValueType)
+ return val;
+ auto loadedVal = context->builder->emitLoad(val.getSimple());
+ auto innerLegalVal = coerceToLegalType(context, LegalType::simple(resultValueType), LegalVal::simple(loadedVal));
+ return LegalVal::implicitDeref(innerLegalVal);
+ }
+ ShortList<IRInst*> fields;
+ for (auto field : structType->getFields())
+ {
+ if (as<IRVoidType>(field->getFieldType()))
+ continue;
+ auto fieldVal = coerceToLegalType(
+ context,
+ LegalType::simple(field->getFieldType()),
+ LegalVal::simple(context->builder->emitFieldExtract(simpleVal, field->getKey())));
+ fields.add(fieldVal.getSimple());
+ }
+ return LegalVal::simple(context->builder->emitMakeStruct(structType, (UInt)fields.getCount(), fields.getArrayView().getBuffer()));
+ }
+ case LegalType::Flavor::implicitDeref:
+ {
+ auto innerVal = val;
+ if (innerVal.flavor == LegalVal::Flavor::implicitDeref)
+ innerVal = innerVal.getImplicitDeref();
+ else if (innerVal.flavor == LegalVal::Flavor::simple)
+ innerVal = LegalVal::simple(context->builder->emitLoad(innerVal.getSimple()));
+ innerVal = coerceToLegalType(context, type.getImplicitDeref()->valueType, innerVal);
+ return LegalVal::implicitDeref(innerVal);
+ }
+ case LegalType::Flavor::pair:
+ {
+ if (val.flavor == LegalVal::Flavor::pair)
+ return val;
+ else if (val.flavor == LegalVal::Flavor::simple)
+ {
+ auto pairType = type.getPair();
+ auto pairInfo = pairType->pairInfo;
+ LegalVal ordinaryVal = coerceToLegalType(context, pairType->ordinaryType, val);
+ LegalVal specialVal = coerceToLegalType(context, pairType->specialType, val);
+ return LegalVal::pair(ordinaryVal, specialVal, pairInfo);
+ }
+ else if (val.flavor == LegalVal::Flavor::implicitDeref)
+ {
+ LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref());
+ return LegalVal::implicitDeref(innerVal);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unhandled legal type coercion");
+ UNREACHABLE_RETURN(LegalVal());
+ }
+ }
+ case LegalType::Flavor::tuple:
+ {
+ if (val.flavor == LegalVal::Flavor::tuple)
+ return val;
+ else if (val.flavor == LegalVal::Flavor::simple)
+ {
+ auto tupleType = type.getTuple();
+ RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal();
+ auto simpleVal = val.getSimple();
+ for (auto elem : tupleType->elements)
+ {
+ IRInst* elementVal = nullptr;
+ if (as<IRPtrTypeBase>(simpleVal->getDataType()) || as<IRPointerLikeType>(simpleVal->getDataType()))
+ elementVal = context->builder->emitFieldAddress(simpleVal, elem.key);
+ else
+ elementVal = context->builder->emitFieldExtract(simpleVal, elem.key);
+ LegalVal legalElementVal = coerceToLegalType(context, elem.type, LegalVal::simple(elementVal));
+ TuplePseudoVal::Element tupleElem;
+ tupleElem.key = elem.key;
+ tupleElem.val = legalElementVal;
+ tupleVal->elements.add(tupleElem);
+ }
+ return LegalVal::tuple(tupleVal);
+ }
+ else if (val.flavor == LegalVal::Flavor::implicitDeref)
+ {
+ LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref());
+ return LegalVal::implicitDeref(innerVal);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unhandled legal type coercion");
+ UNREACHABLE_RETURN(LegalVal());
+ }
+ }
+ default:
+ return val;
+ }
+}
+
static LegalVal legalizeInst(
IRTypeLegalizationContext* context,
IRInst* inst,
LegalType type,
ArrayView<LegalVal> args)
{
+ LegalVal result = LegalVal();
switch (inst->getOp())
{
case kIROp_Load:
- return legalizeLoad(context, args[0]);
+ result = legalizeLoad(context, args[0]);
+ break;
case kIROp_GetValueFromBoundInterface:
- return args[0];
+ result = args[0];
+ break;
case kIROp_FieldAddress:
- return legalizeFieldAddress(context, type, args[0], args[1]);
+ result = legalizeFieldAddress(context, type, args[0], args[1]);
+ break;
case kIROp_FieldExtract:
- return legalizeFieldExtract(context, type, args[0], args[1]);
+ result = legalizeFieldExtract(context, type, args[0], args[1]);
+ break;
case kIROp_GetElement:
- return legalizeGetElement(context, type, args[0], args[1]);
+ result = legalizeGetElement(context, type, args[0], args[1]);
+ break;
case kIROp_GetElementPtr:
- return legalizeGetElementPtr(context, type, args[0], args[1]);
+ result = legalizeGetElementPtr(context, type, args[0], args[1]);
+ break;
case kIROp_Store:
- return legalizeStore(context, args[0], args[1]);
+ result = legalizeStore(context, args[0], args[1]);
+ break;
case kIROp_Call:
- return legalizeCall(context, (IRCall*)inst);
+ result = legalizeCall(context, (IRCall*)inst);
+ break;
case kIROp_Return:
- return legalizeRetVal(context, args[0], (IRReturn*)inst);
+ result = legalizeRetVal(context, args[0], (IRReturn*)inst);
+ break;
case kIROp_DebugVar:
- return legalizeDebugVar(context, type, (IRDebugVar*)inst);
+ result = legalizeDebugVar(context, type, (IRDebugVar*)inst);
+ break;
case kIROp_DebugValue:
- return legalizeDebugValue(context, args[0], args[1], (IRDebugValue*)inst);
+ result = legalizeDebugValue(context, args[0], args[1], (IRDebugValue*)inst);
+ break;
case kIROp_MakeStruct:
- return legalizeMakeStruct(
+ result = legalizeMakeStruct(
context,
type,
args.getBuffer(),
inst->getOperandCount());
+ break;
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
- return legalizeMakeArray(
+ result = legalizeMakeArray(
context,
type,
args.getBuffer(),
inst->getOperandCount(),
inst->getOp());
+ break;
case kIROp_DefaultConstruct:
- return legalizeDefaultConstruct(
+ result = legalizeDefaultConstruct(
context,
type);
-
+ break;
case kIROp_unconditionalBranch:
case kIROp_loop:
- return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst);
+ result = legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst);
+ break;
case kIROp_undefined:
return LegalVal();
case kIROp_GpuForeach:
@@ -2042,6 +2177,8 @@ static LegalVal legalizeInst(
SLANG_UNEXPECTED("non-simple operand(s)!");
break;
}
+ result = coerceToLegalType(context, type, result);
+ return result;
}
static UnownedStringSlice findNameHint(IRInst* inst)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 88065cedc..1fc15f185 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4870,6 +4870,27 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitFieldExtract(IRInst* base, IRInst* fieldKey)
+ {
+ IRType* resultType = nullptr;
+ auto valueType = base->getDataType();
+ auto structType = as<IRStructType>(valueType);
+ SLANG_RELEASE_ASSERT(structType);
+ for (auto child : valueType->getChildren())
+ {
+ auto field = as<IRStructField>(child);
+ if (!field)
+ continue;
+ if (field->getKey() == fieldKey)
+ {
+ resultType = field->getFieldType();
+ break;
+ }
+ }
+ SLANG_RELEASE_ASSERT(resultType);
+ return emitFieldExtract(resultType, base, fieldKey);
+ }
+
IRInst* IRBuilder::emitFieldExtract(
IRType* type,
IRInst* base,
@@ -4903,6 +4924,40 @@ namespace Slang
}
IRInst* IRBuilder::emitFieldAddress(
+ IRInst* basePtr,
+ IRInst* fieldKey)
+ {
+ AddressSpace addrSpace = AddressSpace::Generic;
+ IRInst* valueType = nullptr;
+ auto basePtrType = unwrapAttributedType(basePtr->getDataType());
+ if (auto ptrType = as<IRPtrTypeBase>(basePtrType))
+ {
+ addrSpace = ptrType->getAddressSpace();
+ valueType = ptrType->getValueType();
+ }
+ else if (auto ptrLikeType = as<IRPointerLikeType>(basePtrType))
+ {
+ valueType = ptrLikeType->getElementType();
+ }
+ IRType* resultType = nullptr;
+ auto structType = as<IRStructType>(valueType);
+ SLANG_RELEASE_ASSERT(structType);
+ for (auto child : valueType->getChildren())
+ {
+ auto field = as<IRStructField>(child);
+ if (!field)
+ continue;
+ if (field->getKey() == fieldKey)
+ {
+ resultType = field->getFieldType();
+ break;
+ }
+ }
+ SLANG_RELEASE_ASSERT(resultType);
+ return emitFieldAddress(getPtrType(kIROp_PtrType, resultType, addrSpace), basePtr, fieldKey);
+ }
+
+ IRInst* IRBuilder::emitFieldAddress(
IRType* type,
IRInst* base,
IRInst* field)
@@ -5080,23 +5135,9 @@ namespace Slang
{
for (auto access : accessChain)
{
- auto basePtrType = cast<IRPtrTypeBase>(basePtr->getDataType());
- auto valueType = unwrapAttributedType(basePtrType->getValueType());
- IRType* resultType = nullptr;
if (auto structKey = as<IRStructKey>(access))
{
- auto structType = as<IRStructType>(valueType);
- SLANG_RELEASE_ASSERT(structType);
- for (auto field : structType->getFields())
- {
- if (field->getKey() == structKey)
- {
- resultType = field->getFieldType();
- break;
- }
- }
- SLANG_RELEASE_ASSERT(resultType);
- basePtr = emitFieldAddress(getPtrType(kIROp_PtrType, resultType, basePtrType->getAddressSpace()), basePtr, structKey);
+ basePtr = emitFieldAddress(basePtr, structKey);
}
else
{
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index 66c0044b6..aa69bac79 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -467,22 +467,14 @@ struct TupleTypeBuilder
IRBuilder* builder = context->getBuilder();
IRStructType* ordinaryStructType = builder->createStructType();
ordinaryStructType->sourceLoc = originalStructType->sourceLoc;
- copyNameHintAndDebugDecorations(ordinaryStructType, originalStructType);
+ originalStructType->transferDecorationsTo(ordinaryStructType);
+ copyNameHintAndDebugDecorations(originalStructType, ordinaryStructType);
// The new struct type will appear right after the original in the IR,
// so that we can be sure any instruction that could reference the
// original can also reference the new one.
ordinaryStructType->insertAfter(originalStructType);
- // Mark the original type for removal once all the other legalization
- // activity is completed. This is necessary because both the original
- // and replacement type have the same mangled name, so they would
- // collide.
- //
- // (Also, the original type wasn't legal - that was the whole point...)
- originalStructType->removeFromParent();
- context->replacedInstructions.add(originalStructType);
-
for(auto ee : ordinaryElements)
{
// We will ensure that all the original fields are represented,