summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-17 15:57:22 -0700
committerGitHub <noreply@github.com>2023-03-17 15:57:22 -0700
commit7f11f883d0781952f002b3aa3222a3aa0040f18a (patch)
tree08eaf10fef39211fbc3f124679bfe8a35775a5a7
parent4b55bf6d75bdeed087728505a1c9b43d3a99af8d (diff)
Add support for emitting cuda kernel and host functions. (#2712)
* Add support for emitting cuda kernel and host functions. * Update test. * Fix cuda preamble emit. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--prelude/slang-cuda-prelude.h20
-rw-r--r--source/slang/core.meta.slang6
-rw-r--r--source/slang/slang-ast-expr.h10
-rw-r--r--source/slang/slang-ast-modifier.h10
-rw-r--r--source/slang/slang-check-expr.cpp43
-rw-r--r--source/slang/slang-check-impl.h1
-rw-r--r--source/slang/slang-emit-cpp.h2
-rw-r--r--source/slang/slang-emit-cuda.cpp32
-rw-r--r--source/slang/slang-emit-cuda.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h5
-rw-r--r--source/slang/slang-ir-insts.h31
-rw-r--r--source/slang/slang-ir.cpp14
-rw-r--r--source/slang/slang-lower-to-ir.cpp50
-rw-r--r--source/slang/slang-options.cpp6
-rw-r--r--source/slang/slang-parser.cpp24
-rw-r--r--tests/autodiff/cuda-kernel-export.slang14
16 files changed, 257 insertions, 13 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index cb1bb188b..7a4c5a918 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -1,3 +1,11 @@
+#define SLANG_PRELUDE_EXPORT
+
+#ifdef __CUDACC_RTC__
+#define SLANG_CUDA_RTC 1
+#else
+#define SLANG_CUDA_RTC 0
+#endif
+
// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
// For this to work NVRTC needs to have the path to the CUDA SDK.
//
@@ -341,6 +349,7 @@ SLANG_CUDA_VECTOR_FLOAT_OPS(__half)
SLANG_CUDA_FLOAT_VECTOR_MOD(float)
SLANG_CUDA_FLOAT_VECTOR_MOD(double)
+#if SLANG_CUDA_RTC
#define SLANG_MAKE_VECTOR(T) \
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x, T y) { return T##2{x, y}; }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x, T y, T z) { return T##3{ x, y, z }; }\
@@ -355,15 +364,24 @@ SLANG_MAKE_VECTOR(float)
SLANG_MAKE_VECTOR(double)
SLANG_MAKE_VECTOR(longlong)
SLANG_MAKE_VECTOR(ulonglong)
+#endif
+
#if SLANG_CUDA_ENABLE_HALF
SLANG_MAKE_VECTOR(__half)
#endif
+#if SLANG_CUDA_RTC
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); }
+#else
+#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); }
+#endif
SLANG_MAKE_VECTOR_FROM_SCALAR(int)
SLANG_MAKE_VECTOR_FROM_SCALAR(uint)
SLANG_MAKE_VECTOR_FROM_SCALAR(short)
@@ -378,10 +396,12 @@ SLANG_MAKE_VECTOR_FROM_SCALAR(double)
SLANG_MAKE_VECTOR_FROM_SCALAR(__half)
#endif
+
template<typename T, int n>
struct GetVectorTypeImpl {};
#define GET_VECTOR_TYPE_IMPL(T, n)\
+template<>\
struct GetVectorTypeImpl<T,n>\
{\
typedef T##n type;\
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 790aa3d55..6581cc605 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3078,6 +3078,12 @@ attribute_syntax [DllExport] : DllExportAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [CudaDeviceExport] : CudaDeviceExportAttribute;
+__attributeTarget(FuncDecl)
+attribute_syntax [CudaHost] : CudaHostAttribute;
+
+__attributeTarget(FuncDecl)
+attribute_syntax [CudaKernel] : CudaKernelAttribute;
+
__attributeTarget(InterfaceDecl)
attribute_syntax [COM(guid: String)] : ComInterfaceAttribute;
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 301dded49..0a875fb50 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -477,6 +477,16 @@ class BackwardDifferentiateExpr: public DifferentiateExpr
SLANG_AST_CLASS(BackwardDifferentiateExpr)
};
+ /// An expression of the form `__dispatch_kernel(fn, threadGroupSize, dispatchSize)` to
+ /// dispatch a compute kernel from host.
+ /// The `fn` operand is stored in the inherited `baseFunction` field.
+class DispatchKernelExpr : public HigherOrderInvokeExpr
+{
+ SLANG_AST_CLASS(DispatchKernelExpr)
+ Expr* threadGroupSize = nullptr; ///< Second operand; checked to be an integer 3-vector.
+ Expr* dispatchSize = nullptr; ///< Third operand; checked to be an integer 3-vector.
+};
+
/// An express to mark its inner expression as an intended non-differential call.
class TreatAsDifferentiableExpr : public Expr
{
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index c58b7de21..26303d6ad 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1068,6 +1068,16 @@ class CudaDeviceExportAttribute : public Attribute
SLANG_AST_CLASS(CudaDeviceExportAttribute)
};
+class CudaKernelAttribute : public Attribute
+{
+ SLANG_AST_CLASS(CudaKernelAttribute)
+};
+
+class CudaHostAttribute : public Attribute
+{
+ SLANG_AST_CLASS(CudaHostAttribute)
+};
+
class DerivativeMemberAttribute : public Attribute
{
SLANG_AST_CLASS(DerivativeMemberAttribute)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f749361d7..8d8a72dd6 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2333,11 +2333,12 @@ namespace Slang
}
};
- struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions
+ template<typename ExprASTType>
+ struct PassthroughHighOrderExprCheckingActionsBase : HigherOrderInvokeExprCheckingActions
{
virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
{
- return semantics->getASTBuilder()->create<PrimalSubstituteExpr>();
+ return semantics->getASTBuilder()->create<ExprASTType>();
}
void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
{
@@ -2431,7 +2432,43 @@ namespace Slang
Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
{
- PrimalSubstituteExprCheckingActions actions;
+ PassthroughHighOrderExprCheckingActionsBase<PrimalSubstituteExpr> actions;
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
+ }
+
+ Expr* SemanticsExprVisitor::visitDispatchKernelExpr(DispatchKernelExpr* expr) // Checks both size operands, then the higher-order callee.
+ {
+ auto isInt3Type = [this](Type* type) // True only for a vector with an integer base type and a constant element count of 3.
+ {
+ auto vectorType = as<VectorExpressionType>(type);
+ if (!vectorType)
+ return false;
+ if (!isIntegerBaseType(getVectorBaseType(vectorType)))
+ return false;
+ auto constElementCount = as<ConstantIntVal>(vectorType->elementCount);
+ if (!constElementCount)
+ return false;
+ return constElementCount->value == 3;
+ };
+ expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this);
+ if (!isInt3Type(expr->threadGroupSize->type.type)) // Any integer 3-vector passes, although the diagnostic below names "uint3".
+ {
+ getSink()->diagnose(
+ expr->threadGroupSize,
+ Diagnostics::typeMismatch,
+ "uint3",
+ expr->threadGroupSize->type);
+ }
+ expr->dispatchSize = dispatchExpr(expr->dispatchSize, *this); // Checking continues even if the previous operand was diagnosed.
+ if (!isInt3Type(expr->dispatchSize->type.type))
+ {
+ getSink()->diagnose(
+ expr->dispatchSize,
+ Diagnostics::typeMismatch,
+ "uint3",
+ expr->dispatchSize->type);
+ }
+ PassthroughHighOrderExprCheckingActionsBase<DispatchKernelExpr> actions; // Same pass-through actions as visitPrimalSubstituteExpr, instantiated for this node type.
return _checkHigherOrderInvokeExpr(this, expr, &actions);
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 3d40b10e9..4181ca43b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1953,6 +1953,7 @@ namespace Slang
Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr);
Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr);
+ Expr* visitDispatchKernelExpr(DispatchKernelExpr* expr);
Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr);
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 71c382f87..89978e68a 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -105,7 +105,7 @@ protected:
/// Maybe emits 'export' (such that visible outside binary/dll) and `extern "C"` naming
void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC);
- void _maybeEmitExportLike(IRInst* inst);
+ virtual void _maybeEmitExportLike(IRInst* inst);
static bool _isVariable(IROp op);
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index a151ab0e2..846b3b1f2 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -223,9 +223,21 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
void CUDASourceEmitter::emitFunctionPreambleImpl(IRInst* inst)
{
- if(inst && inst->findDecoration<IREntryPointDecoration>())
+ if (!inst)
+ return;
+ if (inst->findDecoration<IREntryPointDecoration>())
{
m_writer->emit("extern \"C\" __global__ ");
+ return;
+ }
+
+ if (inst->findDecoration<IRCudaKernelDecoration>())
+ {
+ m_writer->emit("__global__ ");
+ }
+ else if (inst->findDecoration<IRCudaHostDecoration>())
+ {
+ m_writer->emit("__host__ ");
}
else
{
@@ -608,6 +620,24 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit(")optixGetSbtDataPointer())");
return true;
}
+ case kIROp_DispatchKernel: // Emits `fn<<<a, b>>>(args...)` for a dispatch-kernel instruction.
+ {
+ auto dispatchInst = as<IRDispatchKernel>(inst);
+ emitOperand(dispatchInst->getBaseFn(), getInfo(EmitOp::Atomic));
+ m_writer->emit("<<<"); // NOTE(review): CUDA's execution configuration is <<<grid, block>>>; threadGroupSize lands in the grid slot here - confirm the operand order is intended.
+ emitOperand(dispatchInst->getThreadGroupSize(), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(dispatchInst->getDispatchSize(), getInfo(EmitOp::General));
+ m_writer->emit(">>>(");
+ for (UInt i = 0; i < dispatchInst->getArgCount(); i++) // Call arguments, comma-separated.
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ emitOperand(dispatchInst->getArg(i), getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ return true;
+ }
default: break;
}
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index 8a907dc7c..2ba7dd6a3 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -79,6 +79,8 @@ protected:
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE;
+ virtual void _maybeEmitExportLike(IRInst* inst) SLANG_OVERRIDE { SLANG_UNUSED(inst); }
+
virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
virtual void emitGlobalRTTISymbolPrefix() SLANG_OVERRIDE;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 4516a6bc3..04e08293f 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -667,6 +667,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// An `[entryPoint]` decoration marks a function that represents a shader entry point
INST(EntryPointDecoration, entryPoint, 2, 0)
+ INST(CudaKernelDecoration, CudaKernel, 0, 0)
+ INST(CudaHostDecoration, CudaHost, 0, 0)
+
/// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point.
INST(EntryPointParamDecoration, entryPointParam, 0, 0)
@@ -904,6 +907,8 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
INST(PrimalSubstitute, PrimalSubstitute, 1, 0)
+INST(DispatchKernel, DispatchKernel, 3, 0)
+
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index cf58c22d0..cf66f1f6b 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -421,6 +421,9 @@ struct IREntryPointDecoration : IRDecoration
IRStringLit* getModuleName() { return cast<IRStringLit>(getOperand(2)); }
};
+IR_SIMPLE_DECORATION(CudaHostDecoration)
+IR_SIMPLE_DECORATION(CudaKernelDecoration)
+
struct IRGeometryInputPrimitiveTypeDecoration: IRDecoration
{
IR_PARENT_ISA(GeometryInputPrimitiveTypeDecoration)
@@ -913,13 +916,26 @@ struct IRPrimalSubstitute : IRInst
{
kOp = kIROp_PrimalSubstitute
};
- // The base function for the call.
- IRUse base;
IRInst* getBaseFn() { return getOperand(0); }
IR_LEAF_ISA(PrimalSubstitute)
};
+struct IRDispatchKernel : IRInst // Operands: [0] kernel function, [1] thread group size, [2] dispatch size, [3..] call arguments.
+{
+ enum
+ {
+ kOp = kIROp_DispatchKernel
+ };
+ IRInst* getBaseFn() { return getOperand(0); }
+ IRInst* getThreadGroupSize() { return getOperand(1); }
+ IRInst* getDispatchSize() { return getOperand(2); }
+ UInt getArgCount() { return getOperandCount() - 3; } // Everything after the three fixed operands.
+ IRInst* getArg(UInt i) { return getOperand(3 + i); } // `i` is relative to the first call argument.
+
+ IR_LEAF_ISA(DispatchKernel)
+};
+
// Dictionary item mapping a type with a corresponding
// IDifferentiable witness table
//
@@ -2880,6 +2896,7 @@ public:
IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn);
IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn);
IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn);
+ IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential);
@@ -3773,6 +3790,16 @@ public:
addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName));
}
+ void addCudaHostDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_CudaHostDecoration);
+ }
+
+ void addCudaKernelDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_CudaKernelDecoration);
+ }
+
void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName)
{
IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) };
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9f877969a..206d73e3f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3159,6 +3159,20 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs)
+ {
+ List<IRInst*> args = {baseFn, threadGroupSize, dispatchSize}; // Fixed operands 0..2, in the order IRDispatchKernel's accessors read them.
+ args.addRange(inArgs, (Index)argCount); // Call arguments follow, starting at operand 3.
+ auto inst = createInst<IRDispatchKernel>(
+ this,
+ kIROp_DispatchKernel,
+ type,
+ (Int)args.getCount(),
+ args.getBuffer());
+ addInst(inst); // Insert the new instruction through the builder before returning it.
+ return inst;
+ }
+
IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn)
{
auto inst = createInst<IRBackwardDifferentiatePrimal>(
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e6b6b5c61..8164723b6 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -643,8 +643,29 @@ LoweredValInfo emitCallToVal(
switch (tryEnv.clauseType)
{
case TryClauseType::None:
- return LoweredValInfo::simple(
- builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args));
+ {
+ auto callee = getSimpleVal(context, funcVal);
+ if (auto dispatchKernel = as<IRDispatchKernel>(callee))
+ {
+ // If callee is a dispatch kernel expr, don't emit call(dispatchKernel, ...), instead
+ // emit a dispatchKernel(high_order_args, actual_args).
+ auto result = LoweredValInfo::simple(builder->emitDispatchKernelInst(
+ type,
+ dispatchKernel->getBaseFn(),
+ dispatchKernel->getThreadGroupSize(),
+ dispatchKernel->getDispatchSize(),
+ argCount,
+ args));
+ SLANG_ASSERT(!dispatchKernel->hasUses());
+ dispatchKernel->removeAndDeallocate();
+ return result;
+ }
+ else
+ {
+ // Reuse `callee`; materializing `funcVal` a second time would be redundant.
+ return LoweredValInfo::simple(builder->emitCallInst(type, callee, argCount, args));
+ }
+ }
case TryClauseType::Standard:
{
@@ -1160,6 +1181,14 @@ static void addLinkageDecoration(
builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addPublicDecoration(inst);
}
+ if (decl->findModifier<CudaHostAttribute>())
+ {
+ builder->addCudaHostDecoration(inst);
+ }
+ if (decl->findModifier<CudaKernelAttribute>())
+ {
+ builder->addCudaKernelDecoration(inst);
+ }
}
static void addLinkageDecoration(
@@ -3204,6 +3233,23 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitDispatchKernelExpr(DispatchKernelExpr* expr)
+ {
+ auto kernelFn = lowerSubExpr(expr->baseFunction);
+ SLANG_ASSERT(kernelFn.flavor == LoweredValInfo::Flavor::Simple);
+ auto loweredThreadGroupSize = lowerRValueExpr(context, expr->threadGroupSize);
+ auto loweredDispatchSize = lowerRValueExpr(context, expr->dispatchSize);
+ // The call arguments are not known yet, so the inst is emitted with an empty
+ // argument list; `emitCallToVal` supplies them when the enclosing call is lowered.
+ return LoweredValInfo::simple(getBuilder()->emitDispatchKernelInst(
+ lowerType(context, expr->type),
+ kernelFn.val,
+ getSimpleVal(context, loweredThreadGroupSize),
+ getSimpleVal(context, loweredDispatchSize),
+ 0,
+ nullptr));
+ }
+
LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
auto baseVal = lowerSubExpr(expr->arrayExpr);
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index f6932b9e0..714e2c99d 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -2189,10 +2189,12 @@ struct OptionsParser
}
// If we don't have any raw outputs but do have a raw target,
- // and output type is callable, add an empty' rawOutput.
+ // add an empty rawOutput for certain targets where the expected behavior is obvious.
if (rawOutputs.getCount() == 0 &&
rawTargets.getCount() == 1 &&
- ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable)
+ (rawTargets[0].format == CodeGenTarget::HostCPPSource ||
+ rawTargets[0].format == CodeGenTarget::CUDASource ||
+ ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable))
{
RawOutput rawOutput;
rawOutput.impliedFormat = rawTargets[0].format;
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 1bddfb9cf..cdeb0b259 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2158,6 +2158,27 @@ namespace Slang
return parseBackwardDifferentiate(parser);
}
+ static Expr* parseDispatchKernel(Parser* parser)
+ {
+ auto* node = parser->astBuilder->create<DispatchKernelExpr>();
+ // Syntax after the keyword: `(` kernel `,` threadGroupSize `,` dispatchSize `)`.
+ parser->ReadToken(TokenType::LParent);
+ // The three operands are read strictly left to right.
+ node->baseFunction = parser->ParseArgExpr();
+ parser->ReadToken(TokenType::Comma);
+ node->threadGroupSize = parser->ParseArgExpr();
+ parser->ReadToken(TokenType::Comma);
+ node->dispatchSize = parser->ParseArgExpr();
+ parser->ReadToken(TokenType::RParent);
+
+ return node;
+ }
+
+ static NodeBase* parseDispatchKernel(Parser* parser, void* /* unused */)
+ {
+ return parseDispatchKernel(parser);
+ }
+
/// Parse a `This` type expression
static Expr* parseThisTypeExpr(Parser* parser)
{
@@ -6721,7 +6742,8 @@ namespace Slang
_makeParseExpr("no_diff", parseTreatAsDifferentiableExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
_makeParseExpr("__fwd_diff", parseForwardDifferentiate),
- _makeParseExpr("__bwd_diff", parseBackwardDifferentiate)
+ _makeParseExpr("__bwd_diff", parseBackwardDifferentiate),
+ _makeParseExpr("__dispatch_kernel", parseDispatchKernel)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang
index 0db4d8cea..54442498b 100644
--- a/tests/autodiff/cuda-kernel-export.slang
+++ b/tests/autodiff/cuda-kernel-export.slang
@@ -1,4 +1,4 @@
-//DISABLED_TEST:SIMPLE: -target cuda -line-directive-mode none
+//DISABLE_TEST:SIMPLE: -target cuda -line-directive-mode none
// Verify that we can output a cuda device function with [CudaDeviceExport].
// Disabled until we have FileCheck.
@@ -27,3 +27,15 @@ void diffF(inout DifferentialPair<MixedType> m, float dout)
{
__bwd_diff(f)(m, dout);
}
+
+[CudaKernel]
+void myKernel(float* inValues, float* outValues)
+{
+ outValues[0] = sin(inValues[0]);
+}
+
+[CudaHost] // Host-side wrapper that launches myKernel through __dispatch_kernel.
+public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispatchSize)
+{
+ __dispatch_kernel(myKernel, uint3(128, 1, 1), dispatchSize)(inValues, outValues);
+} \ No newline at end of file