diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-30 20:28:34 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-30 20:28:34 -0700 |
| commit | 6e4b82741893be55f6216c31e19650029c667078 (patch) | |
| tree | fefd4529c6066763653732d7f93ca5cf07027a76 /source | |
| parent | 04e7327a2067c82db3eaef51955f211e148ac933 (diff) | |
Fixes for Metal ParameterBlock support. (#4752)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-shader-parameter-type.cpp | 65 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-shader-parameter-type.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 171 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 71 | ||||
| -rw-r--r-- | source/slang/slang-legalize-types.cpp | 12 |
8 files changed, 309 insertions, 42 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index fe1a25b39..cb3e39dd4 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -859,6 +859,9 @@ DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0 DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") DIAGNOSTIC(56002, Error, attemptToQuerySizeOfUnsizedArray, "cannot obtain the size of an unsized array.") +// Metal +DIAGNOSTIC(56100, Error, constantBufferInParameterBlockNotAllowedOnMetal, "nested 'ConstantBuffer' inside a 'ParameterBlock' is not supported on Metal, use 'ParameterBlock' instead.") + DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.") DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index a8ed469fa..044f79531 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -15,6 +15,7 @@ #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" #include "slang-ir-check-recursive-type.h" +#include "slang-ir-check-shader-parameter-type.h" #include "slang-ir-autodiff.h" #include "slang-ir-defunctionalization.h" #include "slang-ir-dll-export.h" @@ -739,8 +740,15 @@ Result linkAndOptimizeIR( } if (targetProgram->getOptionSet().shouldRunNonEssentialValidation()) + { checkForRecursiveTypes(irModule, sink); + // For some targets, we are more restrictive about what types are allowed + // to be used as shader parameters in ConstantBuffer/ParameterBlock. + // We will check for these restrictions here. + checkForInvalidShaderParameterType(targetRequest, irModule, sink); + } + if (sink->getErrorCount() != 0) return SLANG_FAIL; diff --git a/source/slang/slang-ir-check-shader-parameter-type.cpp b/source/slang/slang-ir-check-shader-parameter-type.cpp new file mode 100644 index 000000000..71833c838 --- /dev/null +++ b/source/slang/slang-ir-check-shader-parameter-type.cpp @@ -0,0 +1,65 @@ +#include "slang-ir-check-shader-parameter-type.h" +#include "slang-ir-util.h" + +namespace Slang +{ + void checkForInvalidShaderParameterTypeForMetal(IRModule* module, DiagnosticSink* sink) + { + HashSet<IRInst*> workListSet; + List<IRInst*> workList; + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_ParameterBlockType) + { + auto type = inst->getOperand(0); + if (workListSet.add(type)) + workList.add(type); + // Diagnose an error on `ParameterBlock<ConstantBuffer<T>>`. + if (type->getOp() == kIROp_ConstantBufferType) + { + bool foundUseSite = false; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (user->sourceLoc.isValid()) + { + sink->diagnose(user, Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal); + foundUseSite = true; + break; + } + } + if (!foundUseSite) + sink->diagnose(inst, Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal); + } + } + } + // Diagnose an error any any struct fields whose type is `ConstantBuffer<T>` if the + // struct is used inside a `ParameterBlock`. + for (Index i = 0; i < workList.getCount(); i++) + { + auto type = workList[i]; + if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + auto fieldType = field->getFieldType(); + if (fieldType->getOp() == kIROp_ConstantBufferType) + { + sink->diagnose(field->getKey(), Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal); + } + if (workListSet.add(fieldType)) + workList.add(fieldType); + } + } + } + } + + void checkForInvalidShaderParameterType( + TargetRequest* target, + IRModule* module, + DiagnosticSink* sink) + { + if (isMetalTarget(target)) + checkForInvalidShaderParameterTypeForMetal(module, sink); + } +}
\ No newline at end of file diff --git a/source/slang/slang-ir-check-shader-parameter-type.h b/source/slang/slang-ir-check-shader-parameter-type.h new file mode 100644 index 000000000..2ecc95fd7 --- /dev/null +++ b/source/slang/slang-ir-check-shader-parameter-type.h @@ -0,0 +1,13 @@ +#pragma once + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + class TargetRequest; + + void checkForInvalidShaderParameterType( + TargetRequest* targetReq, + IRModule* module, + DiagnosticSink* sink); +} diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 3aa7d1f64..795a79c28 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4075,11 +4075,19 @@ public: IRInst* emitIsType(IRInst* value, IRInst* witness, IRInst* typeOperand, IRInst* targetWitness); IRInst* emitFieldExtract( + IRInst* base, + IRInst* fieldKey); + + IRInst* emitFieldExtract( IRType* type, IRInst* base, IRInst* field); IRInst* emitFieldAddress( + IRInst* basePtr, + IRInst* fieldKey); + + IRInst* emitFieldAddress( IRType* type, IRInst* basePtr, IRInst* field); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 503b528b2..4d7759881 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1270,7 +1270,6 @@ static LegalVal legalizeFieldAddress( default: return LegalVal::simple( builder->emitFieldAddress( - type.getSimple(), legalPtrOperand.getSimple(), fieldKey)); } @@ -1971,67 +1970,203 @@ static LegalVal legalizeDefaultConstruct( } } +// If a legalized `val` has a different flavor than `type`, try to coerce it to `type`. +// +static LegalVal coerceToLegalType( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal val) +{ + switch (type.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + case LegalType::Flavor::simple: + { + if (val.flavor != LegalVal::Flavor::simple) + return val; + auto simpleVal = val.getSimple(); + if (simpleVal->getDataType() == type.getSimple()) + return val; + + auto resultType = type.getSimple(); + auto structType = as<IRStructType>(resultType); + if (!structType) + { + auto resultValueType = tryGetPointedToType(context->builder, resultType); + if (!resultValueType) + return val; + auto valValueType = tryGetPointedToType(context->builder, simpleVal->getDataType()); + if (!valValueType) + return val; + if (resultValueType == valValueType) + return val; + auto loadedVal = context->builder->emitLoad(val.getSimple()); + auto innerLegalVal = coerceToLegalType(context, LegalType::simple(resultValueType), LegalVal::simple(loadedVal)); + return LegalVal::implicitDeref(innerLegalVal); + } + ShortList<IRInst*> fields; + for (auto field : structType->getFields()) + { + if (as<IRVoidType>(field->getFieldType())) + continue; + auto fieldVal = coerceToLegalType( + context, + LegalType::simple(field->getFieldType()), + LegalVal::simple(context->builder->emitFieldExtract(simpleVal, field->getKey()))); + fields.add(fieldVal.getSimple()); + } + return LegalVal::simple(context->builder->emitMakeStruct(structType, (UInt)fields.getCount(), fields.getArrayView().getBuffer())); + } + case LegalType::Flavor::implicitDeref: + { + auto innerVal = val; + if (innerVal.flavor == LegalVal::Flavor::implicitDeref) + innerVal = innerVal.getImplicitDeref(); + else if (innerVal.flavor == LegalVal::Flavor::simple) + innerVal = LegalVal::simple(context->builder->emitLoad(innerVal.getSimple())); + innerVal = coerceToLegalType(context, type.getImplicitDeref()->valueType, innerVal); + return LegalVal::implicitDeref(innerVal); + } + case LegalType::Flavor::pair: + { + if (val.flavor == LegalVal::Flavor::pair) + return val; + else if (val.flavor == LegalVal::Flavor::simple) + { + auto pairType = type.getPair(); + auto pairInfo = pairType->pairInfo; + LegalVal ordinaryVal = coerceToLegalType(context, pairType->ordinaryType, val); + LegalVal specialVal = coerceToLegalType(context, pairType->specialType, val); + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + else if (val.flavor == LegalVal::Flavor::implicitDeref) + { + LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref()); + return LegalVal::implicitDeref(innerVal); + } + else + { + SLANG_UNEXPECTED("unhandled legal type coercion"); + UNREACHABLE_RETURN(LegalVal()); + } + } + case LegalType::Flavor::tuple: + { + if (val.flavor == LegalVal::Flavor::tuple) + return val; + else if (val.flavor == LegalVal::Flavor::simple) + { + auto tupleType = type.getTuple(); + RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal(); + auto simpleVal = val.getSimple(); + for (auto elem : tupleType->elements) + { + IRInst* elementVal = nullptr; + if (as<IRPtrTypeBase>(simpleVal->getDataType()) || as<IRPointerLikeType>(simpleVal->getDataType())) + elementVal = context->builder->emitFieldAddress(simpleVal, elem.key); + else + elementVal = context->builder->emitFieldExtract(simpleVal, elem.key); + LegalVal legalElementVal = coerceToLegalType(context, elem.type, LegalVal::simple(elementVal)); + TuplePseudoVal::Element tupleElem; + tupleElem.key = elem.key; + tupleElem.val = legalElementVal; + tupleVal->elements.add(tupleElem); + } + return LegalVal::tuple(tupleVal); + } + else if (val.flavor == LegalVal::Flavor::implicitDeref) + { + LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref()); + return LegalVal::implicitDeref(innerVal); + } + else + { + SLANG_UNEXPECTED("unhandled legal type coercion"); + UNREACHABLE_RETURN(LegalVal()); + } + } + default: + return val; + } +} + static LegalVal legalizeInst( IRTypeLegalizationContext* context, IRInst* inst, LegalType type, ArrayView<LegalVal> args) { + LegalVal result = LegalVal(); switch (inst->getOp()) { case kIROp_Load: - return legalizeLoad(context, args[0]); + result = legalizeLoad(context, args[0]); + break; case kIROp_GetValueFromBoundInterface: - return args[0]; + result = args[0]; + break; case kIROp_FieldAddress: - return legalizeFieldAddress(context, type, args[0], args[1]); + result = legalizeFieldAddress(context, type, args[0], args[1]); + break; case kIROp_FieldExtract: - return legalizeFieldExtract(context, type, args[0], args[1]); + result = legalizeFieldExtract(context, type, args[0], args[1]); + break; case kIROp_GetElement: - return legalizeGetElement(context, type, args[0], args[1]); + result = legalizeGetElement(context, type, args[0], args[1]); + break; case kIROp_GetElementPtr: - return legalizeGetElementPtr(context, type, args[0], args[1]); + result = legalizeGetElementPtr(context, type, args[0], args[1]); + break; case kIROp_Store: - return legalizeStore(context, args[0], args[1]); + result = legalizeStore(context, args[0], args[1]); + break; case kIROp_Call: - return legalizeCall(context, (IRCall*)inst); + result = legalizeCall(context, (IRCall*)inst); + break; case kIROp_Return: - return legalizeRetVal(context, args[0], (IRReturn*)inst); + result = legalizeRetVal(context, args[0], (IRReturn*)inst); + break; case kIROp_DebugVar: - return legalizeDebugVar(context, type, (IRDebugVar*)inst); + result = legalizeDebugVar(context, type, (IRDebugVar*)inst); + break; case kIROp_DebugValue: - return legalizeDebugValue(context, args[0], args[1], (IRDebugValue*)inst); + result = legalizeDebugValue(context, args[0], args[1], (IRDebugValue*)inst); + break; case kIROp_MakeStruct: - return legalizeMakeStruct( + result = legalizeMakeStruct( context, type, args.getBuffer(), inst->getOperandCount()); + break; case kIROp_MakeArray: case kIROp_MakeArrayFromElement: - return legalizeMakeArray( + result = legalizeMakeArray( context, type, args.getBuffer(), inst->getOperandCount(), inst->getOp()); + break; case kIROp_DefaultConstruct: - return legalizeDefaultConstruct( + result = legalizeDefaultConstruct( context, type); - + break; case kIROp_unconditionalBranch: case kIROp_loop: - return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst); + result = legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst); + break; case kIROp_undefined: return LegalVal(); case kIROp_GpuForeach: @@ -2042,6 +2177,8 @@ static LegalVal legalizeInst( SLANG_UNEXPECTED("non-simple operand(s)!"); break; } + result = coerceToLegalType(context, type, result); + return result; } static UnownedStringSlice findNameHint(IRInst* inst) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 88065cedc..1fc15f185 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4870,6 +4870,27 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitFieldExtract(IRInst* base, IRInst* fieldKey) + { + IRType* resultType = nullptr; + auto valueType = base->getDataType(); + auto structType = as<IRStructType>(valueType); + SLANG_RELEASE_ASSERT(structType); + for (auto child : valueType->getChildren()) + { + auto field = as<IRStructField>(child); + if (!field) + continue; + if (field->getKey() == fieldKey) + { + resultType = field->getFieldType(); + break; + } + } + SLANG_RELEASE_ASSERT(resultType); + return emitFieldExtract(resultType, base, fieldKey); + } + IRInst* IRBuilder::emitFieldExtract( IRType* type, IRInst* base, @@ -4903,6 +4924,40 @@ namespace Slang } IRInst* IRBuilder::emitFieldAddress( + IRInst* basePtr, + IRInst* fieldKey) + { + AddressSpace addrSpace = AddressSpace::Generic; + IRInst* valueType = nullptr; + auto basePtrType = unwrapAttributedType(basePtr->getDataType()); + if (auto ptrType = as<IRPtrTypeBase>(basePtrType)) + { + addrSpace = ptrType->getAddressSpace(); + valueType = ptrType->getValueType(); + } + else if (auto ptrLikeType = as<IRPointerLikeType>(basePtrType)) + { + valueType = ptrLikeType->getElementType(); + } + IRType* resultType = nullptr; + auto structType = as<IRStructType>(valueType); + SLANG_RELEASE_ASSERT(structType); + for (auto child : valueType->getChildren()) + { + auto field = as<IRStructField>(child); + if (!field) + continue; + if (field->getKey() == fieldKey) + { + resultType = field->getFieldType(); + break; + } + } + SLANG_RELEASE_ASSERT(resultType); + return emitFieldAddress(getPtrType(kIROp_PtrType, resultType, addrSpace), basePtr, fieldKey); + } + + IRInst* IRBuilder::emitFieldAddress( IRType* type, IRInst* base, IRInst* field) @@ -5080,23 +5135,9 @@ namespace Slang { for (auto access : accessChain) { - auto basePtrType = cast<IRPtrTypeBase>(basePtr->getDataType()); - auto valueType = unwrapAttributedType(basePtrType->getValueType()); - IRType* resultType = nullptr; if (auto structKey = as<IRStructKey>(access)) { - auto structType = as<IRStructType>(valueType); - SLANG_RELEASE_ASSERT(structType); - for (auto field : structType->getFields()) - { - if (field->getKey() == structKey) - { - resultType = field->getFieldType(); - break; - } - } - SLANG_RELEASE_ASSERT(resultType); - basePtr = emitFieldAddress(getPtrType(kIROp_PtrType, resultType, basePtrType->getAddressSpace()), basePtr, structKey); + basePtr = emitFieldAddress(basePtr, structKey); } else { diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index 66c0044b6..aa69bac79 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -467,22 +467,14 @@ struct TupleTypeBuilder IRBuilder* builder = context->getBuilder(); IRStructType* ordinaryStructType = builder->createStructType(); ordinaryStructType->sourceLoc = originalStructType->sourceLoc; - copyNameHintAndDebugDecorations(ordinaryStructType, originalStructType); + originalStructType->transferDecorationsTo(ordinaryStructType); + copyNameHintAndDebugDecorations(originalStructType, ordinaryStructType); // The new struct type will appear right after the original in the IR, // so that we can be sure any instruction that could reference the // original can also reference the new one. ordinaryStructType->insertAfter(originalStructType); - // Mark the original type for removal once all the other legalization - // activity is completed. This is necessary because both the original - // and replacement type have the same mangled name, so they would - // collide. - // - // (Also, the original type wasn't legal - that was the whole point...) - originalStructType->removeFromParent(); - context->replacedInstructions.add(originalStructType); - for(auto ee : ordinaryElements) { // We will ensure that all the original fields are represented, |
