diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-17 15:57:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-17 15:57:22 -0700 |
| commit | 7f11f883d0781952f002b3aa3222a3aa0040f18a (patch) | |
| tree | 08eaf10fef39211fbc3f124679bfe8a35775a5a7 /source/slang/slang-check-expr.cpp | |
| parent | 4b55bf6d75bdeed087728505a1c9b43d3a99af8d (diff) | |
Add support for emitting cuda kernel and host functions. (#2712)
* Add support for emitting cuda kernel and host functions.
* Update test.
* Fix cuda preamble emit.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 43 |
1 files changed, 40 insertions, 3 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f749361d7..8d8a72dd6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2333,11 +2333,12 @@ namespace Slang } }; - struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions + template<typename ExprASTType> + struct PassthroughHighOrderExprCheckingActionsBase : HigherOrderInvokeExprCheckingActions { virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { - return semantics->getASTBuilder()->create<PrimalSubstituteExpr>(); + return semantics->getASTBuilder()->create<ExprASTType>(); } void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { @@ -2431,7 +2432,43 @@ namespace Slang Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) { - PrimalSubstituteExprCheckingActions actions; + PassthroughHighOrderExprCheckingActionsBase<PrimalSubstituteExpr> actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitDispatchKernelExpr(DispatchKernelExpr* expr) + { + auto isInt3Type = [this](Type* type) + { + auto vectorType = as<VectorExpressionType>(type); + if (!vectorType) + return false; + if (!isIntegerBaseType(getVectorBaseType(vectorType))) + return false; + auto constElementCount = as<ConstantIntVal>(vectorType->elementCount); + if (!constElementCount) + return false; + return constElementCount->value == 3; + }; + expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this); + if (!isInt3Type(expr->threadGroupSize->type.type)) + { + getSink()->diagnose( + expr->threadGroupSize, + Diagnostics::typeMismatch, + "uint3", + expr->threadGroupSize->type); + } + expr->dispatchSize = dispatchExpr(expr->dispatchSize, *this); + if (!isInt3Type(expr->dispatchSize->type.type)) + { + getSink()->diagnose( + expr->dispatchSize, + Diagnostics::typeMismatch, + "uint3", + expr->dispatchSize->type); + } + PassthroughHighOrderExprCheckingActionsBase<DispatchKernelExpr> actions; return _checkHigherOrderInvokeExpr(this, expr, &actions); } |
