summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp36
1 files changed, 36 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0f9da12c4..921bd38e9 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7700,6 +7700,16 @@ namespace Slang
}
}
}
+
+ // If this method is intended to be a CUDA kernel, verify that the return type is void.
+ if (decl->findModifier<CudaKernelAttribute>())
+ {
+ if (decl->returnType.type && !decl->returnType.type->equals(m_astBuilder->getVoidType()))
+ {
+ getSink()->diagnose(decl, Diagnostics::cudaKernelMustReturnVoid);
+ }
+ }
+
checkVisibility(decl);
}
@@ -9547,6 +9557,30 @@ namespace Slang
checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
+ static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*)
+ {
+ // If the method is also marked differentiable, check that the data types are either non-differentiable
+ // or marked with no_diff.
+ //
+ // Note: This is a temporary restriction until we have a more complete story for differentiability.
+ //
+ if (funcDecl->findModifier<DifferentiableAttribute>())
+ {
+ for (auto paramDecl : funcDecl->getParameters())
+ {
+ auto paramType = paramDecl->type;
+
+ if (visitor->isTypeDifferentiable(paramType))
+ {
+ if (!paramDecl->hasModifier<NoDiffModifier>())
+ {
+ visitor->getSink()->diagnose(paramDecl, Diagnostics::differentiableKernelEntryPointCannotHaveDifferentiableParams);
+ }
+ }
+ }
+ }
+ }
+
template<typename TDerivativeAttr, typename TDerivativeOfAttr>
bool tryCheckDerivativeOfAttributeImpl(
SemanticsVisitor* visitor,
@@ -9747,6 +9781,8 @@ namespace Slang
checkDerivativeAttribute(this, decl, bwdDerivativeAttr);
else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr))
checkDerivativeAttribute(this, decl, primalAttr);
+ else if (auto cudaKernelAttr = as<CudaKernelAttribute>(attr))
+ checkCudaKernelAttribute(this, decl, cudaKernelAttr);
}
}