From f96a3fea6704da866e96e453f722a951c214ba28 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 18 Mar 2024 16:41:40 -0700 Subject: Fix SPIRV for mesh shaders, checks for invalid target code&recursion. (#3788) * Fix #3780. * Fixers #3781. * Add test for #3781. * Diagnose error on unsupported builtin intrinsic types. * Add check for recursion. * Fix. * Fix. * Fix recursion detection. * Fix. * Fix. * Fix recursion logic. * More fix. --- source/slang/slang-diagnostic-defs.h | 2 + source/slang/slang-emit-spirv.cpp | 6 ++ source/slang/slang-emit.cpp | 5 +- source/slang/slang-ir-check-unsupported-inst.cpp | 71 ++++++++++++++++++++++++ source/slang/slang-ir-check-unsupported-inst.h | 10 ++++ source/slang/slang-ir-spirv-legalize.cpp | 6 +- source/slang/slang-ir-util.cpp | 20 ++++++- source/slang/slang-ir-util.h | 2 + 8 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 source/slang/slang-ir-check-unsupported-inst.cpp create mode 100644 source/slang/slang-ir-check-unsupported-inst.h (limited to 'source/slang') diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index db17c92a0..ba57e63f9 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -792,6 +792,8 @@ DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.") DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.") +DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.") +DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.") DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index afb1b3a63..6ecffecd1 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -3060,6 +3060,12 @@ struct SPIRVEmitContext case Stage::Callable: requireSPIRVCapability(SpvCapabilityRayTracingKHR); ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_ray_tracing")); + break; + case Stage::Mesh: + case Stage::Amplification: + requireSPIRVCapability(SpvCapabilityMeshShadingEXT); + ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_mesh_shader")); + break; default: break; } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ee38996e6..649858e51 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -8,6 +8,7 @@ #include "slang-ir-any-value-inference.h" #include "slang-ir-bind-existentials.h" #include "slang-ir-byte-address-legalize.h" +#include "slang-ir-check-unsupported-inst.h" #include "slang-ir-collect-global-uniforms.h" #include "slang-ir-cleanup-void.h" #include "slang-ir-composite-reg-to-mem.h" @@ -1061,7 +1062,9 @@ Result linkAndOptimizeIR( outLinkedIR.metadata = metadata; - return SLANG_OK; + checkUnsupportedInst(codeGenContext->getTargetReq(), irModule, sink); + + return sink->getErrorCount() == 0 ? SLANG_OK : SLANG_FAIL; } SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outArtifact) diff --git a/source/slang/slang-ir-check-unsupported-inst.cpp b/source/slang/slang-ir-check-unsupported-inst.cpp new file mode 100644 index 000000000..c89928af5 --- /dev/null +++ b/source/slang/slang-ir-check-unsupported-inst.cpp @@ -0,0 +1,71 @@ +#include "slang-ir-check-unsupported-inst.h" + +#include "slang-ir.h" +#include "slang-ir-util.h" + +namespace Slang +{ + bool isCPUTarget(TargetRequest* targetReq); + + bool checkRecursionImpl(HashSet& checkedFuncs, HashSet& callStack, IRFunc* func, DiagnosticSink* sink) + { + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + auto callInst = as(inst); + if (!callInst) + continue; + auto callee = as(callInst->getCallee()); + if (!callee) + continue; + if (!callStack.add(callee)) + { + sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee); + return false; + } + if (checkedFuncs.add(callee)) + checkRecursionImpl(checkedFuncs, callStack, callee, sink); + callStack.remove(callee); + } + } + return true; + } + + void checkRecursion(HashSet& checkedFuncs, IRFunc* func, DiagnosticSink* sink) + { + HashSet callStack; + if (checkedFuncs.add(func)) + { + callStack.add(func); + checkRecursionImpl(checkedFuncs, callStack, func, sink); + } + } + + void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink) + { + HashSet checkedFuncsForRecursionDetection; + + for (auto globalInst : module->getGlobalInsts()) + { + switch (globalInst->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + { + if (!as(globalInst->getOperand(0))) + { + sink->diagnose(findFirstUseLoc(globalInst), Diagnostics::unsupportedBuiltinType, globalInst); + } + break; + } + case kIROp_Func: + if (!isCPUTarget(target)) + checkRecursion(checkedFuncsForRecursionDetection, as(globalInst), sink); + default: + break; + } + } + } + +} diff --git a/source/slang/slang-ir-check-unsupported-inst.h b/source/slang/slang-ir-check-unsupported-inst.h new file mode 100644 index 000000000..b52306566 --- /dev/null +++ b/source/slang/slang-ir-check-unsupported-inst.h @@ -0,0 +1,10 @@ +#pragma once + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + class TargetRequest; + + void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink); +} diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index a7c14242b..2b7e86f47 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -216,8 +216,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto user = use->getUser(); IRBuilder builder(user); builder.setInsertBefore(user); - if(as(user) || as(user)) + + if((as(user) || as(user)) && + use == user->getOperands()) { + // If the use is the address operand of a getElement or FieldExtract, + // replace the inst with the updated address and continue to follow the use chain. auto basePtrType = as(addr->getDataType()); IRType* ptrType = nullptr; if (basePtrType->hasAddressSpace()) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 2f059d308..e1eb86508 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -230,6 +230,18 @@ bool isSimpleDataType(IRType* type) } } +SourceLoc findFirstUseLoc(IRInst* inst) +{ + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (use->getUser()->sourceLoc.isValid()) + { + return use->getUser()->sourceLoc; + } + } + return inst->sourceLoc; +} + IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue) { auto outerGeneric = as(findOuterGeneric(value)); @@ -582,14 +594,20 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type) getTypeNameHint(sb, as(type)->getValueType()); break; case kIROp_VectorType: + sb << "vector<"; getTypeNameHint(sb, type->getOperand(0)); + sb << ","; getTypeNameHint(sb, as(type)->getElementCount()); + sb << ">"; break; case kIROp_MatrixType: + sb << "matrix<"; getTypeNameHint(sb, type->getOperand(0)); + sb << ","; getTypeNameHint(sb, as(type)->getRowCount()); - sb << "x"; + sb << ","; getTypeNameHint(sb, as(type)->getColumnCount()); + sb << ">"; break; case kIROp_IntLit: sb << as(type)->getValue(); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index fd34d81f7..40ba783b9 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -89,6 +89,8 @@ bool isValueType(IRInst* type); bool isSimpleDataType(IRType* type); +SourceLoc findFirstUseLoc(IRInst* inst); + inline bool isChildInstOf(IRInst* inst, IRInst* parent) { while (inst) -- cgit v1.2.3