diff options
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 673 |
1 files changed, 653 insertions, 20 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index d59d57474..c0adef436 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -2,11 +2,14 @@ #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-diagnostics.h" +#include "slang-ir-autodiff.h" namespace Slang { // Convert a type to a target tuple type. -static IRType* translateToTupleType(IRBuilder& builder, IRType* type) +static IRType* translateToTupleType( + IRBuilder& builder, + IRType* type) { if (as<IRVoidType>(type)) return type; @@ -312,34 +315,490 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) inst->removeAndDeallocate(); } +IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr) +{ + if (as<IRBasicType>(type)) + return type; + + switch (type->getOp()) + { + case kIROp_TensorViewType: + return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); + + case kIROp_StructType: + { + // Create a new struct type with translated fields. + List<IRType*> fieldTypes; + for (auto field : as<IRStructType>(type)->getFields()) + { + fieldTypes.add(translateToHostType(builder, field->getFieldType())); + } + auto hostStructType = builder->createStructType(); + + // Add fields to the struct. + for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++) + { + builder->createStructField(hostStructType, builder->createStructKey(), fieldTypes[i]); + } + + return hostStructType; + } + default: + break; + } + + if (sink) + sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type); + return nullptr; +} + +IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst) +{ + if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType)) + return inst; + + switch (cudaType->getOp()) + { + case kIROp_TensorViewType: + return builder->emitMakeTensorView(cudaType, inst); + + case kIROp_StructType: + { + auto cudaStructType = cast<IRStructType>(cudaType); + auto hostStructType = cast<IRStructType>(hostType); + + List<IRStructField*> cudaFields; + for (auto field : cudaStructType->getFields()) + cudaFields.add(field); + + List<IRStructField*> hostFields; + for (auto field : hostStructType->getFields()) + hostFields.add(field); + + List<IRInst*> resultFields; + for (auto ii = 0; ii < cudaFields.getCount(); ii++) + { + auto cudaField = cudaFields[ii]; + auto hostField = hostFields[ii]; + auto cudaFieldType = cudaField->getFieldType(); + auto hostFieldType = hostField->getFieldType(); + auto castedField = castHostToCUDAType( + builder, + hostFieldType, + cudaFieldType, + builder->emitFieldExtract(hostFieldType, inst, hostField->getKey())); + + SLANG_RELEASE_ASSERT(castedField); + resultFields.add(castedField); + } + + return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer()); + } + + default: + break; + } + + // If translateToHostType worked correctly, we shouldn't get here. + SLANG_UNREACHABLE("unhandled type"); +} + +void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc) +{ + // Given a func with torch binding, we'll generate a reflection function that returns + // a tuple where the first element is another tuple of parameter names, the second + // element is a string containing the name of the fwd-diff function, and the third + // element is a string containing the name of the bwd-diff function. + // + + // Create a new function. + auto reflectionFunc = builder->createFunc(); + builder->setInsertInto(reflectionFunc); + builder->emitBlock(); + + // Go through func & generate a tuple of parameter names. + List<IRInst*> paramNames; + List<IRInst*> paramTypeNames; + UIndex paramCount = 0; + for (auto param : hostFunc->getFirstBlock()->getParams()) + { + if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) + { + paramNames.add(builder->emitGetNativeString(builder->getStringValue(nameHint->getName()))); + } + else + { + StringBuilder argNameBuilder; + argNameBuilder << "param"; + argNameBuilder << paramCount; + + paramNames.add(builder->emitGetNativeString(builder->getStringValue(argNameBuilder.getUnownedSlice()))); + } + paramCount++; + } + + for (auto param : kernelFunc->getParams()) + { + // Check for py-export decoration. + if (auto pyExportHint = param->getDataType()->findDecoration<IRPyExportDecoration>()) + { + paramTypeNames.add( + builder->emitGetNativeString( + builder->getStringValue( + pyExportHint->getExportName()))); + } + else + { + paramTypeNames.add( + builder->emitGetNativeString( + builder->getStringValue( + UnownedStringSlice("")))); + } + } + + // Create a target-tuple-type for the names + auto paramNamesTupleType = builder->getTargetTupleType( + (UInt)paramNames.getCount(), + List<IRType*>().makeRepeated(builder->getNativeStringType(), paramNames.getCount()).getBuffer()); + auto paramNamesTuple = builder->emitMakeTargetTuple(paramNamesTupleType, paramNames.getCount(), paramNames.getBuffer()); + + // Create a target-tuple-type for the type names + auto paramTypeNamesTupleType = builder->getTargetTupleType( + (UInt)paramTypeNames.getCount(), + List<IRType*>().makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()).getBuffer()); + auto paramTypeNamesTuple = builder->emitMakeTargetTuple(paramTypeNamesTupleType, paramTypeNames.getCount(), paramTypeNames.getBuffer()); + + // Find the fwd-diff function name (blank string indicates no fwd-diff) + IRInst* fwdDiffName = builder->getStringValue(UnownedStringSlice("")); + if (auto fwdDiffHint = kernelFunc->findDecoration<IRCudaKernelForwardDerivativeDecoration>()) + { + auto fwdDiffFunc = fwdDiffHint->getForwardDerivativeFunc(); + + if (auto fwdDiffFuncExternHint = fwdDiffFunc->findDecoration<IRExternCppDecoration>()) + { + fwdDiffName = builder->emitGetNativeString(builder->getStringValue(fwdDiffFuncExternHint->getName())); + } + } + + // Find the bwd-diff function name (blank string indicates no bwd-diff) + IRInst* bwdDiffName = builder->getStringValue(UnownedStringSlice("")); + if (auto bwdDiffHint = kernelFunc->findDecoration<IRCudaKernelBackwardDerivativeDecoration>()) + { + auto bwdDiffFunc = bwdDiffHint->getBackwardDerivativeFunc(); + + if (auto bwdDiffFuncExternHint = bwdDiffFunc->findDecoration<IRExternCppDecoration>()) + { + bwdDiffName = builder->emitGetNativeString(builder->getStringValue(bwdDiffFuncExternHint->getName())); + } + } + + auto stringType = builder->getNativeStringType(); + auto returnTupleType = builder->getTargetTupleType( + 4, + List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType).getBuffer()); + + // Create a target-tuple-type for the names + auto returnTupleArgs = List<IRInst*>( paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName ); + auto returnTuple = builder->emitMakeTargetTuple( + returnTupleType, + returnTupleArgs.getCount(), + returnTupleArgs.getBuffer()); + builder->emitReturn(returnTuple); + + // Set function type. + auto funcType = builder->getFuncType(List<IRType*>(), returnTupleType); + reflectionFunc->setFullType(funcType); + + // Set function name. + StringBuilder reflFuncExportName; + auto hostFuncExportName = hostFunc->findDecoration<IRExternCppDecoration>()->getName(); + reflFuncExportName << "__funcinfo__" << hostFuncExportName; + + builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); + builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); + builder->addPublicDecoration(reflectionFunc); + builder->addKeepAliveDecoration(reflectionFunc); +} + +IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr) +{ + auto typeMap = [&](IRType* t) -> IRType* { + if (auto tensorViewType = as<IRTensorViewType>(t)) + return builder->getTorchTensorType(tensorViewType->getElementType()); + }; + + auto type = translateToHostType(builder, param->getDataType(), sink); + if (outType) + *outType = type; + auto hostParam = builder->emitParam(type); + // Add a namehint to the param by appending the suffix "_host". + if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) + { + builder->addNameHintDecoration(hostParam, nameHint->getName()); + } + + // Then cast the param to the appropriate type. + if (auto castedParam = castHostToCUDAType(builder, type, param->getDataType(), hostParam)) + return castedParam; + + return nullptr; +} + +void markTypeForPyExport(IRType* type, DiagnosticSink* sink) +{ + // If it's a basic type, we're done. + if (as<IRBasicType>(type) || as<IRVoidType>(type)) + return; + + // If it's a struct type, mark for py-export. + if (auto structType = as<IRStructType>(type)) + { + IRBuilder builder(structType->getModule()); + + // If it already has a py-export decoration, we're done. + if (!structType->findDecoration<IRPyExportDecoration>()) + { + // Look for a name hint. + UnownedStringSlice nameHint; + if (auto nameHintDecoration = structType->findDecoration<IRNameHintDecoration>()) + nameHint = nameHintDecoration->getName(); + else + { + // If there's no name hint, we can't export this type. + SLANG_UNEXPECTED("struct marked for export has no name"); + } + + builder.addPyExportDecoration(structType, nameHint); + } + + for (auto field : structType->getFields()) + { + markTypeForPyExport(field->getFieldType(), sink); + } + return; + } +} + +void generateReflectionForType(IRType* type, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + // Emit a function that returns a py::list. + // The list will contain the names of all the fields of the type. + // + + // TODO: Fix this to avoid emitting the same type reflection multiple times. + if (!type->findDecoration<IRPyExportDecoration>()) + return; + + IRBuilder builder(type->getModule()); + + auto reflFunc = builder.createFunc(); + builder.setInsertInto(reflFunc); + builder.emitBlock(); + + List<IRInst*> fieldNames; + List<IRInst*> fieldTypeNames; + + switch (type->getOp()) + { + case kIROp_StructType: + { + for (auto field : as<IRStructType>(type)->getFields()) + { + auto structKey = field->getKey(); + // Look for a name hint. + if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>()) + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(nameHintDecoration->getName()))); + else + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); + + if (!field->getFieldType()->findDecoration<IRPyExportDecoration>()) + { + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); + continue; + } + + auto fieldType = field->getFieldType(); + + fieldTypeNames.add( + builder.emitGetNativeString( + builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName()))); + } + break; + } + default: + break; + } + + auto _nameListTupleType = builder.getTargetTupleType( + (UInt)fieldNames.getCount(), + List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldNames.getCount()).getBuffer()); + auto nameListTuple = builder.emitMakeTargetTuple(_nameListTupleType, (UInt)fieldNames.getCount(), fieldNames.getBuffer()); + + auto _typeNameListTupleType = builder.getTargetTupleType( + (UInt)fieldTypeNames.getCount(), + List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()).getBuffer()); + auto typeNameListTuple = builder.emitMakeTargetTuple(_typeNameListTupleType, (UInt)fieldTypeNames.getCount(), fieldTypeNames.getBuffer()); + + auto _nameAndTypeTupleType = builder.getTargetTupleType(2, List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer()); + auto nameAndTypeTuple = builder.emitMakeTargetTuple( + _nameAndTypeTupleType, + 2, + List<IRInst*>(nameListTuple, typeNameListTuple).getBuffer()); + builder.emitReturn(nameAndTypeTuple); + + // Set function type. + auto funcType = builder.getFuncType(List<IRType*>(), _nameAndTypeTupleType); + reflFunc->setFullType(funcType); + + // Set function name. + StringBuilder reflFuncExportName; + reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName(); + + builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); + builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); + builder.addPublicDecoration(reflFunc); + builder.addKeepAliveDecoration(reflFunc); +} + +IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) +{ + // Check that the function has an auto-bind decoration + if (!func->findDecoration<IRAutoPyBindCudaDecoration>()) + return nullptr; + + // We will create a CudaHost function that will call func. + // But before that, we need to determine the type of CudaHost. + // + // To determine the type, first we will append two uint3 parameters to the function. + // with the names "__blockSize" and "__gridSize", these will serve as input block and + // grid size parameters for the launch. + // + // Then, we will go over the parameters of func, and find a host-mapping for each type + // by calling mapTypeToCudaHostType(IRType*), which turns structs into tuples, and + // IRTensorViewType to IRTorchTensorType. + // + // Finally, we will create a CudaHost function and transfer the name of func over to + // the generated method. + // + // The function body will first perform any conversion logic needed to convert the + // parameters from the CudaHost types to the types of func, and then use dispatch_kernel + // to dispatch func with the given block and grid size. + // + + // Create new function. + IRBuilder builder(func->getModule()); + + auto hostFunc = builder.createFunc(); + builder.setInsertInto(hostFunc); + builder.emitBlock(); + + List<IRType*> hostParamTypes; + + // Add the two uint3 parameters + auto uint3Type = builder.getVectorType(builder.getUIntType(), 3); + + auto blockSizeParam = builder.emitParam(uint3Type); + hostParamTypes.add(uint3Type); + builder.addNameHintDecoration(blockSizeParam, UnownedStringSlice("__blockSize")); + + auto gridSizeParam = builder.emitParam(uint3Type); + hostParamTypes.add(uint3Type); + builder.addNameHintDecoration(gridSizeParam, UnownedStringSlice("__gridSize")); + + List<IRInst*> mappedParams; + for (auto param : func->getFirstBlock()->getParams()) + { + IRType* hostParamType; + mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType)); + hostParamTypes.add(hostParamType); + markTypeForPyExport(param->getDataType(), sink); // Should we be marking the host type? + } + + // Dispatch the original function. + builder.emitDispatchKernelInst( + builder.getVoidType(), + func, + blockSizeParam, + gridSizeParam, + mappedParams.getCount(), + mappedParams.getBuffer()); + + builder.emitReturn(); + + IRFuncType* hostFuncType = builder.getFuncType(hostParamTypes, builder.getVoidType()); + hostFunc->setFullType(hostFuncType); + + // Add a torch entry point decoration to the host function to mark + // for further processing. + // + if (auto pybindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>()) + { + // Mark for further processing of torch-specific insts. + builder.addTorchEntryPointDecoration(hostFunc, pybindCudaHint->getFunctionName()); + // Mark for host-side emit logic. + builder.addCudaHostDecoration(hostFunc); + // Keep alive. This method will be accessed externally. + builder.addPublicDecoration(hostFunc); + builder.addKeepAliveDecoration(hostFunc); + } + + if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) + { + // Transfer to the host function. + builder.addExternCppDecoration(hostFunc, externCppHint->getName()); + } + + if (auto exportInfoHint = func->findDecoration<IRAutoPyBindExportInfoDecoration>()) + generateReflectionFunc(&builder, func, hostFunc); + + return hostFunc; +} + void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) { List<IRFunc*> workList; List<IRFunc*> cudaKernels; + List<IRFunc*> autoBindRequests; + List<IRType*> typesToExport; for (auto globalInst : module->getGlobalInsts()) { - auto func = as<IRFunc>(globalInst); - if (!func) - continue; - if (func->findDecoration<IRTorchEntryPointDecoration>()) - { - workList.add(func); - } - else if (func->findDecoration<IRCudaKernelDecoration>()) + if (auto func = as<IRFunc>(globalInst)) { - cudaKernels.add(func); + if (func->findDecoration<IRAutoPyBindCudaDecoration>()) + { + autoBindRequests.add(func); + } + if (func->findDecoration<IRTorchEntryPointDecoration>()) + { + workList.add(func); + } + else if (func->findDecoration<IRCudaKernelDecoration>()) + { + cudaKernels.add(func); + } + else + { + // Remove all other export decorations if this is not a cuda host func. + if (auto decor = func->findDecoration<IRPublicDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRHLSLExportDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRKeepAliveDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRDllExportDecoration>()) + decor->removeAndDeallocate(); + } } - else + } + + // Generate CUDA wrappers for all functions that have the auto-bind decoration. + for (auto func : autoBindRequests) + { + if (auto hostFunc = generateCUDAWrapperForFunc(func, sink)) { - // Remove all other export decorations if this is not a cuda host func. - if (auto decor = func->findDecoration<IRPublicDecoration>()) - decor->removeAndDeallocate(); - if (auto decor = func->findDecoration<IRHLSLExportDecoration>()) - decor->removeAndDeallocate(); - if (auto decor = func->findDecoration<IRKeepAliveDecoration>()) - decor->removeAndDeallocate(); - if (auto decor = func->findDecoration<IRDllExportDecoration>()) - decor->removeAndDeallocate(); + // Add generated wrapper to worklist for python bindings. + workList.add(hostFunc); } } @@ -355,6 +814,20 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) block = nextBlock; } } + + for (auto globalInst : module->getGlobalInsts()) + { + if (auto type = as<IRType>(globalInst)) + { + if (type->findDecoration<IRPyExportDecoration>()) + { + typesToExport.add(type); + } + } + } + + for (auto type : typesToExport) + generateReflectionForType(type, sink); } // Remove all [TorchEntryPoint] functions when emitting CUDA source. @@ -372,4 +845,164 @@ void removeTorchKernels(IRModule* module) inst->removeAndDeallocate(); } +void handleAutoBindNames(IRModule* module) +{ + // We need to rewrite extern-cpp names for functions that have an auto-bind decoration. + // since the name needs to be used for the host function. + // + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>()) + { + // Find an extern decoration on the original function, and append a prefix to the name. + if (auto externCppHint = globalInst->findDecoration<IRExternCppDecoration>()) + { + IRBuilder builder(module); + + // Change the name of the original function. + StringBuilder nameBuilder; + nameBuilder << "__kernel__" << externCppHint->getName(); + externCppHint->removeAndDeallocate(); + builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); + } + } + } +} + +void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + for (auto globalInst : module->getGlobalInsts()) + { + if (!as<IRFunc>(globalInst)) + continue; + + // Look for methods marked with auto-bind and are differentiable. + if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>()) + { + if(globalInst->findDecoration<IRForwardDifferentiableDecoration>() || + globalInst->findDecoration<IRBackwardDifferentiableDecoration>()) + { + // We'll generate a wrapper for this method that calls fwd_diff(fn) + // but an important thing to note is that we won't actually employ the usual + // differentiable typing rules. We'll assume none of the parameters are + // differentiable & throw a warning if some are. This is because, for the auto-binding + // scenario, we expect to only see tensor types, and their differentiation is handled using + // tensor _pair_ types which handle the differentiable loads/stores through custom derivatives + // + // For now, the user is expected to explicitly use the tensor pair types, so we will simply copy over + // the original function's signature. + // In the future, when we update the type system to be able to specify the corresponding pair type, + // we can update this logic. + // + + // Create a new wrapper function. + IRBuilder builder(module); + auto func = cast<IRFunc>(globalInst); + auto wrapperFunc = builder.createFunc(); + builder.setInsertInto(wrapperFunc); + builder.emitBlock(); + + // Clone the parameter list. + List<IRInst*> params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(builder.emitParam(param->getFullType())); + } + + wrapperFunc->setFullType(func->getFullType()); + + auto fwdDiffFunc = builder.emitForwardDifferentiateInst(func->getFullType(), func); + auto fwdDiffCall = builder.emitCallInst( + func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer()); + + builder.emitReturn(fwdDiffCall); + + // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. + if (auto kernelHint = func->findDecoration<IRCudaKernelDecoration>()) + builder.addCudaKernelDecoration(wrapperFunc); + + // Add an auto-pybind-cuda decoration to the wrapper function to further generate the + // host-side binding for the derivative kernel. + // + { + auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>(); + StringBuilder nameBuilder; + nameBuilder << autoPyBindCudaHint->getFunctionName() << "_fwd_diff"; + builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + // Build a name for the wrapper function: <original_name>_fwd_diff + if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) + { + StringBuilder nameBuilder; + nameBuilder << externCppHint->getName() << "_fwd_diff"; + builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + builder.addPublicDecoration(wrapperFunc); + builder.addKeepAliveDecoration(wrapperFunc); + + builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc); + } + + if (globalInst->findDecoration<IRBackwardDifferentiableDecoration>()) + { + // The reasoning for the reverse-mode is the same as the forward-mode version + // (see above) + // + + // Create a new wrapper function. + IRBuilder builder(module); + auto func = cast<IRFunc>(globalInst); + auto wrapperFunc = builder.createFunc(); + builder.setInsertInto(wrapperFunc); + builder.emitBlock(); + + // Clone the parameter list. + List<IRInst*> params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(builder.emitParam(param->getFullType())); + } + + wrapperFunc->setFullType(func->getFullType()); + + auto fwdDiffFunc = builder.emitBackwardDifferentiateInst(func->getFullType(), func); + auto fwdDiffCall = builder.emitCallInst( + func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer()); + + builder.emitReturn(fwdDiffCall); + + // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. + if (auto kernelHint = func->findDecoration<IRCudaKernelDecoration>()) + builder.addCudaKernelDecoration(wrapperFunc); + + // Add an auto-pybind-cuda decoration to the wrapper function to further generate the + // host-side binding for the derivative kernel. + // + { + auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>(); + StringBuilder nameBuilder; + nameBuilder << autoPyBindCudaHint->getFunctionName() << "_bwd_diff"; + builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + // Build a name for the wrapper function: <original_name>_bwd_diff + if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) + { + StringBuilder nameBuilder; + nameBuilder << externCppHint->getName() << "_bwd_diff"; + builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + builder.addPublicDecoration(wrapperFunc); + builder.addKeepAliveDecoration(wrapperFunc); + + builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc); + } + } + } +} + } |
