summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-17 15:57:22 -0700
committerGitHub <noreply@github.com>2023-03-17 15:57:22 -0700
commit7f11f883d0781952f002b3aa3222a3aa0040f18a (patch)
tree08eaf10fef39211fbc3f124679bfe8a35775a5a7 /source/slang/slang-check-expr.cpp
parent4b55bf6d75bdeed087728505a1c9b43d3a99af8d (diff)
Add support for emitting cuda kernel and host functions. (#2712)
* Add support for emitting cuda kernel and host functions. * Update test. * Fix cuda preamble emit. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp43
1 files changed, 40 insertions, 3 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f749361d7..8d8a72dd6 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2333,11 +2333,12 @@ namespace Slang
}
};
- struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions
+ template<typename ExprASTType>
+ struct PassthroughHighOrderExprCheckingActionsBase : HigherOrderInvokeExprCheckingActions
{
virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
{
- return semantics->getASTBuilder()->create<PrimalSubstituteExpr>();
+ return semantics->getASTBuilder()->create<ExprASTType>();
}
void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
{
@@ -2431,7 +2432,43 @@ namespace Slang
Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
{
- PrimalSubstituteExprCheckingActions actions;
+ PassthroughHighOrderExprCheckingActionsBase<PrimalSubstituteExpr> actions;
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
+ }
+
+ Expr* SemanticsExprVisitor::visitDispatchKernelExpr(DispatchKernelExpr* expr)
+ {
+ auto isInt3Type = [this](Type* type)
+ {
+ auto vectorType = as<VectorExpressionType>(type);
+ if (!vectorType)
+ return false;
+ if (!isIntegerBaseType(getVectorBaseType(vectorType)))
+ return false;
+ auto constElementCount = as<ConstantIntVal>(vectorType->elementCount);
+ if (!constElementCount)
+ return false;
+ return constElementCount->value == 3;
+ };
+ expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this);
+ if (!isInt3Type(expr->threadGroupSize->type.type))
+ {
+ getSink()->diagnose(
+ expr->threadGroupSize,
+ Diagnostics::typeMismatch,
+ "uint3",
+ expr->threadGroupSize->type);
+ }
+ expr->dispatchSize = dispatchExpr(expr->dispatchSize, *this);
+ if (!isInt3Type(expr->dispatchSize->type.type))
+ {
+ getSink()->diagnose(
+ expr->dispatchSize,
+ Diagnostics::typeMismatch,
+ "uint3",
+ expr->dispatchSize->type);
+ }
+ PassthroughHighOrderExprCheckingActionsBase<DispatchKernelExpr> actions;
return _checkHigherOrderInvokeExpr(this, expr, &actions);
}