diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-17 15:57:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-17 15:57:22 -0700 |
| commit | 7f11f883d0781952f002b3aa3222a3aa0040f18a (patch) | |
| tree | 08eaf10fef39211fbc3f124679bfe8a35775a5a7 | |
| parent | 4b55bf6d75bdeed087728505a1c9b43d3a99af8d (diff) | |
Add support for emitting cuda kernel and host functions. (#2712)
* Add support for emitting cuda kernel and host functions.
* Update test.
* Fix cuda preamble emit.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 20 | ||||
| -rw-r--r-- | source/slang/core.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 43 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 32 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 31 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-options.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 24 | ||||
| -rw-r--r-- | tests/autodiff/cuda-kernel-export.slang | 14 |
16 files changed, 257 insertions, 13 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index cb1bb188b..7a4c5a918 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -1,3 +1,11 @@ +#define SLANG_PRELUDE_EXPORT + +#ifdef __CUDACC_RTC__ +#define SLANG_CUDA_RTC 1 +#else +#define SLANG_CUDA_RTC 0 +#endif + // Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support. // For this to work NVRTC needs to have the path to the CUDA SDK. // @@ -341,6 +349,7 @@ SLANG_CUDA_VECTOR_FLOAT_OPS(__half) SLANG_CUDA_FLOAT_VECTOR_MOD(float) SLANG_CUDA_FLOAT_VECTOR_MOD(double) +#if SLANG_CUDA_RTC #define SLANG_MAKE_VECTOR(T) \ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x, T y) { return T##2{x, y}; }\ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x, T y, T z) { return T##3{ x, y, z }; }\ @@ -355,15 +364,24 @@ SLANG_MAKE_VECTOR(float) SLANG_MAKE_VECTOR(double) SLANG_MAKE_VECTOR(longlong) SLANG_MAKE_VECTOR(ulonglong) +#endif + #if SLANG_CUDA_ENABLE_HALF SLANG_MAKE_VECTOR(__half) #endif +#if SLANG_CUDA_RTC #define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); } +#else +#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); } +#endif SLANG_MAKE_VECTOR_FROM_SCALAR(int) SLANG_MAKE_VECTOR_FROM_SCALAR(uint) SLANG_MAKE_VECTOR_FROM_SCALAR(short) @@ -378,10 +396,12 @@ SLANG_MAKE_VECTOR_FROM_SCALAR(double) SLANG_MAKE_VECTOR_FROM_SCALAR(__half) #endif + template<typename T, int n> struct GetVectorTypeImpl {}; #define GET_VECTOR_TYPE_IMPL(T, n)\ +template<>\ struct GetVectorTypeImpl<T,n>\ {\ typedef T##n type;\ diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 790aa3d55..6581cc605 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3078,6 +3078,12 @@ attribute_syntax [DllExport] : DllExportAttribute; __attributeTarget(FuncDecl) attribute_syntax [CudaDeviceExport] : CudaDeviceExportAttribute; +__attributeTarget(FuncDecl) +attribute_syntax [CudaHost] : CudaHostAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [CudaKernel] : CudaKernelAttribute; + __attributeTarget(InterfaceDecl) attribute_syntax [COM(guid: String)] : ComInterfaceAttribute; diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 301dded49..0a875fb50 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -477,6 +477,16 @@ class BackwardDifferentiateExpr: public DifferentiateExpr SLANG_AST_CLASS(BackwardDifferentiateExpr) }; + /// An expression of the form `__dispatch_kernel(fn, threadGroupSize, dispatchSize)` to + /// dispatch a compute kernel from host. + /// +class DispatchKernelExpr : public HigherOrderInvokeExpr +{ + SLANG_AST_CLASS(DispatchKernelExpr) + Expr* threadGroupSize; + Expr* dispatchSize; +}; + /// An express to mark its inner expression as an intended non-differential call. class TreatAsDifferentiableExpr : public Expr { diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index c58b7de21..26303d6ad 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1068,6 +1068,16 @@ class CudaDeviceExportAttribute : public Attribute SLANG_AST_CLASS(CudaDeviceExportAttribute) }; +class CudaKernelAttribute : public Attribute +{ + SLANG_AST_CLASS(CudaKernelAttribute) +}; + +class CudaHostAttribute : public Attribute +{ + SLANG_AST_CLASS(CudaHostAttribute) +}; + class DerivativeMemberAttribute : public Attribute { SLANG_AST_CLASS(DerivativeMemberAttribute) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f749361d7..8d8a72dd6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2333,11 +2333,12 @@ namespace Slang } }; - struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions + template<typename ExprASTType> + struct PassthroughHighOrderExprCheckingActionsBase : HigherOrderInvokeExprCheckingActions { virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { - return semantics->getASTBuilder()->create<PrimalSubstituteExpr>(); + return semantics->getASTBuilder()->create<ExprASTType>(); } void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { @@ -2431,7 +2432,43 @@ namespace Slang Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) { - PrimalSubstituteExprCheckingActions actions; + PassthroughHighOrderExprCheckingActionsBase<PrimalSubstituteExpr> actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitDispatchKernelExpr(DispatchKernelExpr* expr) + { + auto isInt3Type = [this](Type* type) + { + auto vectorType = as<VectorExpressionType>(type); + if (!vectorType) + return false; + if (!isIntegerBaseType(getVectorBaseType(vectorType))) + return false; + auto constElementCount = as<ConstantIntVal>(vectorType->elementCount); + if (!constElementCount) + return false; + return constElementCount->value == 3; + }; + expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this); + if (!isInt3Type(expr->threadGroupSize->type.type)) + { + getSink()->diagnose( + expr->threadGroupSize, + Diagnostics::typeMismatch, + "uint3", + expr->threadGroupSize->type); + } + expr->dispatchSize = dispatchExpr(expr->dispatchSize, *this); + if (!isInt3Type(expr->dispatchSize->type.type)) + { + getSink()->diagnose( + expr->dispatchSize, + Diagnostics::typeMismatch, + "uint3", + expr->dispatchSize->type); + } + PassthroughHighOrderExprCheckingActionsBase<DispatchKernelExpr> actions; return _checkHigherOrderInvokeExpr(this, expr, &actions); } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 3d40b10e9..4181ca43b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1953,6 +1953,7 @@ namespace Slang Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr); + Expr* visitDispatchKernelExpr(DispatchKernelExpr* expr); Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr); diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index 71c382f87..89978e68a 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -105,7 +105,7 @@ protected: /// Maybe emits 'export' (such that visible outside binary/dll) and `extern "C"` naming void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC); - void _maybeEmitExportLike(IRInst* inst); + virtual void _maybeEmitExportLike(IRInst* inst); static bool _isVariable(IROp op); diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index a151ab0e2..846b3b1f2 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -223,9 +223,21 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin void CUDASourceEmitter::emitFunctionPreambleImpl(IRInst* inst) { - if(inst && inst->findDecoration<IREntryPointDecoration>()) + if (!inst) + return; + if (inst->findDecoration<IREntryPointDecoration>()) { m_writer->emit("extern \"C\" __global__ "); + return; + } + + if (inst->findDecoration<IRCudaKernelDecoration>()) + { + m_writer->emit("__global__ "); + } + else if (inst->findDecoration<IRCudaHostDecoration>()) + { + m_writer->emit("__host__ "); } else { @@ -608,6 +620,24 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")optixGetSbtDataPointer())"); return true; } + case kIROp_DispatchKernel: + { + auto dispatchInst = as<IRDispatchKernel>(inst); + emitOperand(dispatchInst->getBaseFn(), getInfo(EmitOp::Atomic)); + m_writer->emit("<<<"); + emitOperand(dispatchInst->getThreadGroupSize(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(dispatchInst->getDispatchSize(), getInfo(EmitOp::General)); + m_writer->emit(">>>("); + for (UInt i = 0; i < dispatchInst->getArgCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + emitOperand(dispatchInst->getArg(i), getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } default: break; } diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h index 8a907dc7c..2ba7dd6a3 100644 --- a/source/slang/slang-emit-cuda.h +++ b/source/slang/slang-emit-cuda.h @@ -79,6 +79,8 @@ protected: virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void _maybeEmitExportLike(IRInst* inst) SLANG_OVERRIDE { SLANG_UNUSED(inst); } + virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; virtual void emitGlobalRTTISymbolPrefix() SLANG_OVERRIDE; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4516a6bc3..04e08293f 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -667,6 +667,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// An `[entryPoint]` decoration marks a function that represents a shader entry point INST(EntryPointDecoration, entryPoint, 2, 0) + INST(CudaKernelDecoration, CudaKernel, 0, 0) + INST(CudaHostDecoration, CudaHost, 0, 0) + /// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point. INST(EntryPointParamDecoration, entryPointParam, 0, 0) @@ -904,6 +907,8 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) INST(PrimalSubstitute, PrimalSubstitute, 1, 0) +INST(DispatchKernel, DispatchKernel, 3, 0) + // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index cf58c22d0..cf66f1f6b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -421,6 +421,9 @@ struct IREntryPointDecoration : IRDecoration IRStringLit* getModuleName() { return cast<IRStringLit>(getOperand(2)); } }; +IR_SIMPLE_DECORATION(CudaHostDecoration) +IR_SIMPLE_DECORATION(CudaKernelDecoration) + struct IRGeometryInputPrimitiveTypeDecoration: IRDecoration { IR_PARENT_ISA(GeometryInputPrimitiveTypeDecoration) @@ -913,13 +916,26 @@ struct IRPrimalSubstitute : IRInst { kOp = kIROp_PrimalSubstitute }; - // The base function for the call. - IRUse base; IRInst* getBaseFn() { return getOperand(0); } IR_LEAF_ISA(PrimalSubstitute) }; +struct IRDispatchKernel : IRInst +{ + enum + { + kOp = kIROp_DispatchKernel + }; + IRInst* getBaseFn() { return getOperand(0); } + IRInst* getThreadGroupSize() { return getOperand(1); } + IRInst* getDispatchSize() { return getOperand(2); } + UInt getArgCount() { return getOperandCount() - 3; } + IRInst* getArg(UInt i) { return getOperand(3 + i); } + + IR_LEAF_ISA(DispatchKernel) +}; + // Dictionary item mapping a type with a corresponding // IDifferentiable witness table // @@ -2880,6 +2896,7 @@ public: IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); + IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); @@ -3773,6 +3790,16 @@ public: addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName)); } + void addCudaHostDecoration(IRInst* value) + { + addDecoration(value, kIROp_CudaHostDecoration); + } + + void addCudaKernelDecoration(IRInst* value) + { + addDecoration(value, kIROp_CudaKernelDecoration); + } + void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName) { IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) }; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9f877969a..206d73e3f 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3159,6 +3159,20 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs) + { + List<IRInst*> args = {baseFn, threadGroupSize, dispatchSize}; + args.addRange(inArgs, (Index)argCount); + auto inst = createInst<IRDispatchKernel>( + this, + kIROp_DispatchKernel, + type, + (Int)args.getCount(), + args.getBuffer()); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn) { auto inst = createInst<IRBackwardDifferentiatePrimal>( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e6b6b5c61..8164723b6 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -643,8 +643,29 @@ LoweredValInfo emitCallToVal( switch (tryEnv.clauseType) { case TryClauseType::None: - return LoweredValInfo::simple( - builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); + { + auto callee = getSimpleVal(context, funcVal); + if (auto dispatchKernel = as<IRDispatchKernel>(callee)) + { + // If callee is a dispatch kernel expr, don't emit call(dispatchKernel, ...), instead + // emit a dispatchKernel(high_order_args, actual_args). + auto result = LoweredValInfo::simple(builder->emitDispatchKernelInst( + type, + dispatchKernel->getBaseFn(), + dispatchKernel->getThreadGroupSize(), + dispatchKernel->getDispatchSize(), + argCount, + args)); + SLANG_ASSERT(!dispatchKernel->hasUses()); + dispatchKernel->removeAndDeallocate(); + return result; + } + else + { + return LoweredValInfo::simple( + builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); + } + } case TryClauseType::Standard: { @@ -1160,6 +1181,14 @@ static void addLinkageDecoration( builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice()); builder->addPublicDecoration(inst); } + if (decl->findModifier<CudaHostAttribute>()) + { + builder->addCudaHostDecoration(inst); + } + if (decl->findModifier<CudaKernelAttribute>()) + { + builder->addCudaKernelDecoration(inst); + } } static void addLinkageDecoration( @@ -3204,6 +3233,23 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + LoweredValInfo visitDispatchKernelExpr(DispatchKernelExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + auto threadSize = lowerRValueExpr(context, expr->threadGroupSize); + auto groupSize = lowerRValueExpr(context, expr->dispatchSize); + // Actual arguments to be filled in when we lower the actual call expr. + // This is handled in `emitCallToVal`. + return LoweredValInfo::simple(getBuilder()->emitDispatchKernelInst( + lowerType(context, expr->type), + baseVal.val, + getSimpleVal(context, threadSize), + getSimpleVal(context, groupSize), + 0, + nullptr)); + } + LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { auto baseVal = lowerSubExpr(expr->arrayExpr); diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index f6932b9e0..714e2c99d 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2189,10 +2189,12 @@ struct OptionsParser } // If we don't have any raw outputs but do have a raw target, - // and output type is callable, add an empty' rawOutput. + // add an empty' rawOutput for certain targets where the expected behavior is obvious. if (rawOutputs.getCount() == 0 && rawTargets.getCount() == 1 && - ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable) + (rawTargets[0].format == CodeGenTarget::HostCPPSource || + rawTargets[0].format == CodeGenTarget::CUDASource || + ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable)) { RawOutput rawOutput; rawOutput.impliedFormat = rawTargets[0].format; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 1bddfb9cf..cdeb0b259 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2158,6 +2158,27 @@ namespace Slang return parseBackwardDifferentiate(parser); } + static Expr* parseDispatchKernel(Parser* parser) + { + DispatchKernelExpr* dispatchExpr = parser->astBuilder->create<DispatchKernelExpr>(); + + parser->ReadToken(TokenType::LParent); + + dispatchExpr->baseFunction = parser->ParseArgExpr(); + parser->ReadToken(TokenType::Comma); + dispatchExpr->threadGroupSize = parser->ParseArgExpr(); + parser->ReadToken(TokenType::Comma); + dispatchExpr->dispatchSize = parser->ParseArgExpr(); + parser->ReadToken(TokenType::RParent); + + return dispatchExpr; + } + + static NodeBase* parseDispatchKernel(Parser* parser, void* /* unused */) + { + return parseDispatchKernel(parser); + } + /// Parse a `This` type expression static Expr* parseThisTypeExpr(Parser* parser) { @@ -6721,7 +6742,8 @@ namespace Slang _makeParseExpr("no_diff", parseTreatAsDifferentiableExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), _makeParseExpr("__fwd_diff", parseForwardDifferentiate), - _makeParseExpr("__bwd_diff", parseBackwardDifferentiate) + _makeParseExpr("__bwd_diff", parseBackwardDifferentiate), + _makeParseExpr("__dispatch_kernel", parseDispatchKernel) }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang index 0db4d8cea..54442498b 100644 --- a/tests/autodiff/cuda-kernel-export.slang +++ b/tests/autodiff/cuda-kernel-export.slang @@ -1,4 +1,4 @@ -//DISABLED_TEST:SIMPLE: -target cuda -line-directive-mode none +//DISABLE_TEST:SIMPLE: -target cuda -line-directive-mode none // Verify that we can output a cuda device function with [CudaDeviceExport]. // Disabled until we have FileCheck. @@ -27,3 +27,15 @@ void diffF(inout DifferentialPair<MixedType> m, float dout) { __bwd_diff(f)(m, dout); } + +[CudaKernel] +void myKernel(float* inValues, float* outValues) +{ + outValues[0] = sin(inValues[0]); +} + +[CudaHost] +public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispathcSize) +{ + __dispatch_kernel(myKernel, uint3(128, 1, 1), dispathcSize)(inValues, outValues); +}
\ No newline at end of file |
