summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-18 16:41:40 -0700
committerGitHub <noreply@github.com>2024-03-18 16:41:40 -0700
commitf96a3fea6704da866e96e453f722a951c214ba28 (patch)
treee14aafe59eca98992593803db19cc3dff0ae1fe1 /source/slang
parent7f6e95917bb1929115b4cffa2ed9035aa8710ee4 (diff)
Fix SPIRV for mesh shaders, checks for invalid target code&recursion. (#3788)
* Fix #3780. * Fixers #3781. * Add test for #3781. * Diagnose error on unsupported builtin intrinsic types. * Add check for recursion. * Fix. * Fix. * Fix recursion detection. * Fix. * Fix. * Fix recursion logic. * More fix.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-spirv.cpp6
-rw-r--r--source/slang/slang-emit.cpp5
-rw-r--r--source/slang/slang-ir-check-unsupported-inst.cpp71
-rw-r--r--source/slang/slang-ir-check-unsupported-inst.h10
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp6
-rw-r--r--source/slang/slang-ir-util.cpp20
-rw-r--r--source/slang/slang-ir-util.h2
8 files changed, 119 insertions, 3 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index db17c92a0..ba57e63f9 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -792,6 +792,8 @@ DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL
DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.")
DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.")
+DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.")
+DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.")
DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index afb1b3a63..6ecffecd1 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -3060,6 +3060,12 @@ struct SPIRVEmitContext
case Stage::Callable:
requireSPIRVCapability(SpvCapabilityRayTracingKHR);
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_ray_tracing"));
+ break;
+ case Stage::Mesh:
+ case Stage::Amplification:
+ requireSPIRVCapability(SpvCapabilityMeshShadingEXT);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_mesh_shader"));
+ break;
default:
break;
}
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index ee38996e6..649858e51 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -8,6 +8,7 @@
#include "slang-ir-any-value-inference.h"
#include "slang-ir-bind-existentials.h"
#include "slang-ir-byte-address-legalize.h"
+#include "slang-ir-check-unsupported-inst.h"
#include "slang-ir-collect-global-uniforms.h"
#include "slang-ir-cleanup-void.h"
#include "slang-ir-composite-reg-to-mem.h"
@@ -1061,7 +1062,9 @@ Result linkAndOptimizeIR(
outLinkedIR.metadata = metadata;
- return SLANG_OK;
+ checkUnsupportedInst(codeGenContext->getTargetReq(), irModule, sink);
+
+ return sink->getErrorCount() == 0 ? SLANG_OK : SLANG_FAIL;
}
SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outArtifact)
diff --git a/source/slang/slang-ir-check-unsupported-inst.cpp b/source/slang/slang-ir-check-unsupported-inst.cpp
new file mode 100644
index 000000000..c89928af5
--- /dev/null
+++ b/source/slang/slang-ir-check-unsupported-inst.cpp
@@ -0,0 +1,71 @@
+#include "slang-ir-check-unsupported-inst.h"
+
+#include "slang-ir.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+ bool isCPUTarget(TargetRequest* targetReq);
+
+ bool checkRecursionImpl(HashSet<IRFunc*>& checkedFuncs, HashSet<IRFunc*>& callStack, IRFunc* func, DiagnosticSink* sink)
+ {
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ auto callInst = as<IRCall>(inst);
+ if (!callInst)
+ continue;
+ auto callee = as<IRFunc>(callInst->getCallee());
+ if (!callee)
+ continue;
+ if (!callStack.add(callee))
+ {
+ sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
+ return false;
+ }
+ if (checkedFuncs.add(callee))
+ checkRecursionImpl(checkedFuncs, callStack, callee, sink);
+ callStack.remove(callee);
+ }
+ }
+ return true;
+ }
+
+ void checkRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
+ {
+ HashSet<IRFunc*> callStack;
+ if (checkedFuncs.add(func))
+ {
+ callStack.add(func);
+ checkRecursionImpl(checkedFuncs, callStack, func, sink);
+ }
+ }
+
+ void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
+ {
+ HashSet<IRFunc*> checkedFuncsForRecursionDetection;
+
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ {
+ if (!as<IRBasicType>(globalInst->getOperand(0)))
+ {
+ sink->diagnose(findFirstUseLoc(globalInst), Diagnostics::unsupportedBuiltinType, globalInst);
+ }
+ break;
+ }
+ case kIROp_Func:
+ if (!isCPUTarget(target))
+ checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink);
+ default:
+ break;
+ }
+ }
+ }
+
+}
diff --git a/source/slang/slang-ir-check-unsupported-inst.h b/source/slang/slang-ir-check-unsupported-inst.h
new file mode 100644
index 000000000..b52306566
--- /dev/null
+++ b/source/slang/slang-ir-check-unsupported-inst.h
@@ -0,0 +1,10 @@
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ class DiagnosticSink;
+ class TargetRequest;
+
+ void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink);
+}
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index a7c14242b..2b7e86f47 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -216,8 +216,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto user = use->getUser();
IRBuilder builder(user);
builder.setInsertBefore(user);
- if(as<IRGetElement>(user) || as<IRFieldExtract>(user))
+
+ if((as<IRGetElement>(user) || as<IRFieldExtract>(user)) &&
+ use == user->getOperands())
{
+ // If the use is the address operand of a getElement or FieldExtract,
+ // replace the inst with the updated address and continue to follow the use chain.
auto basePtrType = as<IRPtrTypeBase>(addr->getDataType());
IRType* ptrType = nullptr;
if (basePtrType->hasAddressSpace())
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 2f059d308..e1eb86508 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -230,6 +230,18 @@ bool isSimpleDataType(IRType* type)
}
}
+SourceLoc findFirstUseLoc(IRInst* inst)
+{
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->sourceLoc.isValid())
+ {
+ return use->getUser()->sourceLoc;
+ }
+ }
+ return inst->sourceLoc;
+}
+
IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue)
{
auto outerGeneric = as<IRGeneric>(findOuterGeneric(value));
@@ -582,14 +594,20 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type)
getTypeNameHint(sb, as<IRRateQualifiedType>(type)->getValueType());
break;
case kIROp_VectorType:
+ sb << "vector<";
getTypeNameHint(sb, type->getOperand(0));
+ sb << ",";
getTypeNameHint(sb, as<IRVectorType>(type)->getElementCount());
+ sb << ">";
break;
case kIROp_MatrixType:
+ sb << "matrix<";
getTypeNameHint(sb, type->getOperand(0));
+ sb << ",";
getTypeNameHint(sb, as<IRMatrixType>(type)->getRowCount());
- sb << "x";
+ sb << ",";
getTypeNameHint(sb, as<IRMatrixType>(type)->getColumnCount());
+ sb << ">";
break;
case kIROp_IntLit:
sb << as<IRIntLit>(type)->getValue();
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index fd34d81f7..40ba783b9 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -89,6 +89,8 @@ bool isValueType(IRInst* type);
bool isSimpleDataType(IRType* type);
+SourceLoc findFirstUseLoc(IRInst* inst);
+
inline bool isChildInstOf(IRInst* inst, IRInst* parent)
{
while (inst)