summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp673
1 files changed, 653 insertions, 20 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index d59d57474..c0adef436 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -2,11 +2,14 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-diagnostics.h"
+#include "slang-ir-autodiff.h"
namespace Slang
{
// Convert a type to a target tuple type.
-static IRType* translateToTupleType(IRBuilder& builder, IRType* type)
+static IRType* translateToTupleType(
+ IRBuilder& builder,
+ IRType* type)
{
if (as<IRVoidType>(type))
return type;
@@ -312,34 +315,490 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
inst->removeAndDeallocate();
}
+IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr)
+{
+ if (as<IRBasicType>(type))
+ return type;
+
+ switch (type->getOp())
+ {
+ case kIROp_TensorViewType:
+ return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
+
+ case kIROp_StructType:
+ {
+ // Create a new struct type with translated fields.
+ List<IRType*> fieldTypes;
+ for (auto field : as<IRStructType>(type)->getFields())
+ {
+ fieldTypes.add(translateToHostType(builder, field->getFieldType()));
+ }
+ auto hostStructType = builder->createStructType();
+
+ // Add fields to the struct.
+ for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++)
+ {
+ builder->createStructField(hostStructType, builder->createStructKey(), fieldTypes[i]);
+ }
+
+ return hostStructType;
+ }
+ default:
+ break;
+ }
+
+ if (sink)
+ sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type);
+ return nullptr;
+}
+
+IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst)
+{
+ if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType))
+ return inst;
+
+ switch (cudaType->getOp())
+ {
+ case kIROp_TensorViewType:
+ return builder->emitMakeTensorView(cudaType, inst);
+
+ case kIROp_StructType:
+ {
+ auto cudaStructType = cast<IRStructType>(cudaType);
+ auto hostStructType = cast<IRStructType>(hostType);
+
+ List<IRStructField*> cudaFields;
+ for (auto field : cudaStructType->getFields())
+ cudaFields.add(field);
+
+ List<IRStructField*> hostFields;
+ for (auto field : hostStructType->getFields())
+ hostFields.add(field);
+
+ List<IRInst*> resultFields;
+ for (auto ii = 0; ii < cudaFields.getCount(); ii++)
+ {
+ auto cudaField = cudaFields[ii];
+ auto hostField = hostFields[ii];
+ auto cudaFieldType = cudaField->getFieldType();
+ auto hostFieldType = hostField->getFieldType();
+ auto castedField = castHostToCUDAType(
+ builder,
+ hostFieldType,
+ cudaFieldType,
+ builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()));
+
+ SLANG_RELEASE_ASSERT(castedField);
+ resultFields.add(castedField);
+ }
+
+ return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer());
+ }
+
+ default:
+ break;
+ }
+
+ // If translateToHostType worked correctly, we shouldn't get here.
+ SLANG_UNREACHABLE("unhandled type");
+}
+
+void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc)
+{
+ // Given a func with torch binding, we'll generate a reflection function that returns
+ // a tuple where the first element is another tuple of parameter names, the second
+ // element is a string containing the name of the fwd-diff function, and the third
+ // element is a string containing the name of the bwd-diff function.
+ //
+
+ // Create a new function.
+ auto reflectionFunc = builder->createFunc();
+ builder->setInsertInto(reflectionFunc);
+ builder->emitBlock();
+
+ // Go through func & generate a tuple of parameter names.
+ List<IRInst*> paramNames;
+ List<IRInst*> paramTypeNames;
+ UIndex paramCount = 0;
+ for (auto param : hostFunc->getFirstBlock()->getParams())
+ {
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ {
+ paramNames.add(builder->emitGetNativeString(builder->getStringValue(nameHint->getName())));
+ }
+ else
+ {
+ StringBuilder argNameBuilder;
+ argNameBuilder << "param";
+ argNameBuilder << paramCount;
+
+ paramNames.add(builder->emitGetNativeString(builder->getStringValue(argNameBuilder.getUnownedSlice())));
+ }
+ paramCount++;
+ }
+
+ for (auto param : kernelFunc->getParams())
+ {
+ // Check for py-export decoration.
+ if (auto pyExportHint = param->getDataType()->findDecoration<IRPyExportDecoration>())
+ {
+ paramTypeNames.add(
+ builder->emitGetNativeString(
+ builder->getStringValue(
+ pyExportHint->getExportName())));
+ }
+ else
+ {
+ paramTypeNames.add(
+ builder->emitGetNativeString(
+ builder->getStringValue(
+ UnownedStringSlice(""))));
+ }
+ }
+
+ // Create a target-tuple-type for the names
+ auto paramNamesTupleType = builder->getTargetTupleType(
+ (UInt)paramNames.getCount(),
+ List<IRType*>().makeRepeated(builder->getNativeStringType(), paramNames.getCount()).getBuffer());
+ auto paramNamesTuple = builder->emitMakeTargetTuple(paramNamesTupleType, paramNames.getCount(), paramNames.getBuffer());
+
+ // Create a target-tuple-type for the type names
+ auto paramTypeNamesTupleType = builder->getTargetTupleType(
+ (UInt)paramTypeNames.getCount(),
+ List<IRType*>().makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()).getBuffer());
+ auto paramTypeNamesTuple = builder->emitMakeTargetTuple(paramTypeNamesTupleType, paramTypeNames.getCount(), paramTypeNames.getBuffer());
+
+ // Find the fwd-diff function name (blank string indicates no fwd-diff)
+ IRInst* fwdDiffName = builder->getStringValue(UnownedStringSlice(""));
+ if (auto fwdDiffHint = kernelFunc->findDecoration<IRCudaKernelForwardDerivativeDecoration>())
+ {
+ auto fwdDiffFunc = fwdDiffHint->getForwardDerivativeFunc();
+
+ if (auto fwdDiffFuncExternHint = fwdDiffFunc->findDecoration<IRExternCppDecoration>())
+ {
+ fwdDiffName = builder->emitGetNativeString(builder->getStringValue(fwdDiffFuncExternHint->getName()));
+ }
+ }
+
+ // Find the bwd-diff function name (blank string indicates no bwd-diff)
+ IRInst* bwdDiffName = builder->getStringValue(UnownedStringSlice(""));
+ if (auto bwdDiffHint = kernelFunc->findDecoration<IRCudaKernelBackwardDerivativeDecoration>())
+ {
+ auto bwdDiffFunc = bwdDiffHint->getBackwardDerivativeFunc();
+
+ if (auto bwdDiffFuncExternHint = bwdDiffFunc->findDecoration<IRExternCppDecoration>())
+ {
+ bwdDiffName = builder->emitGetNativeString(builder->getStringValue(bwdDiffFuncExternHint->getName()));
+ }
+ }
+
+ auto stringType = builder->getNativeStringType();
+ auto returnTupleType = builder->getTargetTupleType(
+ 4,
+ List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType).getBuffer());
+
+ // Create a target-tuple-type for the names
+ auto returnTupleArgs = List<IRInst*>( paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName );
+ auto returnTuple = builder->emitMakeTargetTuple(
+ returnTupleType,
+ returnTupleArgs.getCount(),
+ returnTupleArgs.getBuffer());
+ builder->emitReturn(returnTuple);
+
+ // Set function type.
+ auto funcType = builder->getFuncType(List<IRType*>(), returnTupleType);
+ reflectionFunc->setFullType(funcType);
+
+ // Set function name.
+ StringBuilder reflFuncExportName;
+ auto hostFuncExportName = hostFunc->findDecoration<IRExternCppDecoration>()->getName();
+ reflFuncExportName << "__funcinfo__" << hostFuncExportName;
+
+ builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
+ builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
+ builder->addPublicDecoration(reflectionFunc);
+ builder->addKeepAliveDecoration(reflectionFunc);
+}
+
+IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr)
+{
+ auto typeMap = [&](IRType* t) -> IRType* {
+ if (auto tensorViewType = as<IRTensorViewType>(t))
+ return builder->getTorchTensorType(tensorViewType->getElementType());
+ };
+
+ auto type = translateToHostType(builder, param->getDataType(), sink);
+ if (outType)
+ *outType = type;
+ auto hostParam = builder->emitParam(type);
+ // Add a namehint to the param by appending the suffix "_host".
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ {
+ builder->addNameHintDecoration(hostParam, nameHint->getName());
+ }
+
+ // Then cast the param to the appropriate type.
+ if (auto castedParam = castHostToCUDAType(builder, type, param->getDataType(), hostParam))
+ return castedParam;
+
+ return nullptr;
+}
+
+void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
+{
+ // If it's a basic type, we're done.
+ if (as<IRBasicType>(type) || as<IRVoidType>(type))
+ return;
+
+ // If it's a struct type, mark for py-export.
+ if (auto structType = as<IRStructType>(type))
+ {
+ IRBuilder builder(structType->getModule());
+
+ // If it already has a py-export decoration, we're done.
+ if (!structType->findDecoration<IRPyExportDecoration>())
+ {
+ // Look for a name hint.
+ UnownedStringSlice nameHint;
+ if (auto nameHintDecoration = structType->findDecoration<IRNameHintDecoration>())
+ nameHint = nameHintDecoration->getName();
+ else
+ {
+ // If there's no name hint, we can't export this type.
+ SLANG_UNEXPECTED("struct marked for export has no name");
+ }
+
+ builder.addPyExportDecoration(structType, nameHint);
+ }
+
+ for (auto field : structType->getFields())
+ {
+ markTypeForPyExport(field->getFieldType(), sink);
+ }
+ return;
+ }
+}
+
+void generateReflectionForType(IRType* type, DiagnosticSink* sink)
+{
+ SLANG_UNUSED(sink);
+ // Emit a function that returns a py::list.
+ // The list will contain the names of all the fields of the type.
+ //
+
+ // TODO: Fix this to avoid emitting the same type reflection multiple times.
+ if (!type->findDecoration<IRPyExportDecoration>())
+ return;
+
+ IRBuilder builder(type->getModule());
+
+ auto reflFunc = builder.createFunc();
+ builder.setInsertInto(reflFunc);
+ builder.emitBlock();
+
+ List<IRInst*> fieldNames;
+ List<IRInst*> fieldTypeNames;
+
+ switch (type->getOp())
+ {
+ case kIROp_StructType:
+ {
+ for (auto field : as<IRStructType>(type)->getFields())
+ {
+ auto structKey = field->getKey();
+ // Look for a name hint.
+ if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>())
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(nameHintDecoration->getName())));
+ else
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
+
+ if (!field->getFieldType()->findDecoration<IRPyExportDecoration>())
+ {
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
+ continue;
+ }
+
+ auto fieldType = field->getFieldType();
+
+ fieldTypeNames.add(
+ builder.emitGetNativeString(
+ builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName())));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
+ auto _nameListTupleType = builder.getTargetTupleType(
+ (UInt)fieldNames.getCount(),
+ List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldNames.getCount()).getBuffer());
+ auto nameListTuple = builder.emitMakeTargetTuple(_nameListTupleType, (UInt)fieldNames.getCount(), fieldNames.getBuffer());
+
+ auto _typeNameListTupleType = builder.getTargetTupleType(
+ (UInt)fieldTypeNames.getCount(),
+ List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()).getBuffer());
+ auto typeNameListTuple = builder.emitMakeTargetTuple(_typeNameListTupleType, (UInt)fieldTypeNames.getCount(), fieldTypeNames.getBuffer());
+
+ auto _nameAndTypeTupleType = builder.getTargetTupleType(2, List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer());
+ auto nameAndTypeTuple = builder.emitMakeTargetTuple(
+ _nameAndTypeTupleType,
+ 2,
+ List<IRInst*>(nameListTuple, typeNameListTuple).getBuffer());
+ builder.emitReturn(nameAndTypeTuple);
+
+ // Set function type.
+ auto funcType = builder.getFuncType(List<IRType*>(), _nameAndTypeTupleType);
+ reflFunc->setFullType(funcType);
+
+ // Set function name.
+ StringBuilder reflFuncExportName;
+ reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName();
+
+ builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
+ builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
+ builder.addPublicDecoration(reflFunc);
+ builder.addKeepAliveDecoration(reflFunc);
+}
+
+IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
+{
+ // Check that the function has an auto-bind decoration
+ if (!func->findDecoration<IRAutoPyBindCudaDecoration>())
+ return nullptr;
+
+ // We will create a CudaHost function that will call func.
+ // But before that, we need to determine the type of CudaHost.
+ //
+ // To determine the type, first we will append two uint3 parameters to the function.
+ // with the names "__blockSize" and "__gridSize", these will serve as input block and
+ // grid size parameters for the launch.
+ //
+ // Then, we will go over the parameters of func, and find a host-mapping for each type
+ // by calling mapTypeToCudaHostType(IRType*), which turns structs into tuples, and
+ // IRTensorViewType to IRTorchTensorType.
+ //
+ // Finally, we will create a CudaHost function and transfer the name of func over to
+ // the generated method.
+ //
+ // The function body will first perform any conversion logic needed to convert the
+ // parameters from the CudaHost types to the types of func, and then use dispatch_kernel
+ // to dispatch func with the given block and grid size.
+ //
+
+ // Create new function.
+ IRBuilder builder(func->getModule());
+
+ auto hostFunc = builder.createFunc();
+ builder.setInsertInto(hostFunc);
+ builder.emitBlock();
+
+ List<IRType*> hostParamTypes;
+
+ // Add the two uint3 parameters
+ auto uint3Type = builder.getVectorType(builder.getUIntType(), 3);
+
+ auto blockSizeParam = builder.emitParam(uint3Type);
+ hostParamTypes.add(uint3Type);
+ builder.addNameHintDecoration(blockSizeParam, UnownedStringSlice("__blockSize"));
+
+ auto gridSizeParam = builder.emitParam(uint3Type);
+ hostParamTypes.add(uint3Type);
+ builder.addNameHintDecoration(gridSizeParam, UnownedStringSlice("__gridSize"));
+
+ List<IRInst*> mappedParams;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ IRType* hostParamType;
+ mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType));
+ hostParamTypes.add(hostParamType);
+ markTypeForPyExport(param->getDataType(), sink); // Should we be marking the host type?
+ }
+
+ // Dispatch the original function.
+ builder.emitDispatchKernelInst(
+ builder.getVoidType(),
+ func,
+ blockSizeParam,
+ gridSizeParam,
+ mappedParams.getCount(),
+ mappedParams.getBuffer());
+
+ builder.emitReturn();
+
+ IRFuncType* hostFuncType = builder.getFuncType(hostParamTypes, builder.getVoidType());
+ hostFunc->setFullType(hostFuncType);
+
+ // Add a torch entry point decoration to the host function to mark
+ // for further processing.
+ //
+ if (auto pybindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>())
+ {
+ // Mark for further processing of torch-specific insts.
+ builder.addTorchEntryPointDecoration(hostFunc, pybindCudaHint->getFunctionName());
+ // Mark for host-side emit logic.
+ builder.addCudaHostDecoration(hostFunc);
+ // Keep alive. This method will be accessed externally.
+ builder.addPublicDecoration(hostFunc);
+ builder.addKeepAliveDecoration(hostFunc);
+ }
+
+ if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
+ {
+ // Transfer to the host function.
+ builder.addExternCppDecoration(hostFunc, externCppHint->getName());
+ }
+
+ if (auto exportInfoHint = func->findDecoration<IRAutoPyBindExportInfoDecoration>())
+ generateReflectionFunc(&builder, func, hostFunc);
+
+ return hostFunc;
+}
+
void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
{
List<IRFunc*> workList;
List<IRFunc*> cudaKernels;
+ List<IRFunc*> autoBindRequests;
+ List<IRType*> typesToExport;
for (auto globalInst : module->getGlobalInsts())
{
- auto func = as<IRFunc>(globalInst);
- if (!func)
- continue;
- if (func->findDecoration<IRTorchEntryPointDecoration>())
- {
- workList.add(func);
- }
- else if (func->findDecoration<IRCudaKernelDecoration>())
+ if (auto func = as<IRFunc>(globalInst))
{
- cudaKernels.add(func);
+ if (func->findDecoration<IRAutoPyBindCudaDecoration>())
+ {
+ autoBindRequests.add(func);
+ }
+ if (func->findDecoration<IRTorchEntryPointDecoration>())
+ {
+ workList.add(func);
+ }
+ else if (func->findDecoration<IRCudaKernelDecoration>())
+ {
+ cudaKernels.add(func);
+ }
+ else
+ {
+ // Remove all other export decorations if this is not a cuda host func.
+ if (auto decor = func->findDecoration<IRPublicDecoration>())
+ decor->removeAndDeallocate();
+ if (auto decor = func->findDecoration<IRHLSLExportDecoration>())
+ decor->removeAndDeallocate();
+ if (auto decor = func->findDecoration<IRKeepAliveDecoration>())
+ decor->removeAndDeallocate();
+ if (auto decor = func->findDecoration<IRDllExportDecoration>())
+ decor->removeAndDeallocate();
+ }
}
- else
+ }
+
+ // Generate CUDA wrappers for all functions that have the auto-bind decoration.
+ for (auto func : autoBindRequests)
+ {
+ if (auto hostFunc = generateCUDAWrapperForFunc(func, sink))
{
- // Remove all other export decorations if this is not a cuda host func.
- if (auto decor = func->findDecoration<IRPublicDecoration>())
- decor->removeAndDeallocate();
- if (auto decor = func->findDecoration<IRHLSLExportDecoration>())
- decor->removeAndDeallocate();
- if (auto decor = func->findDecoration<IRKeepAliveDecoration>())
- decor->removeAndDeallocate();
- if (auto decor = func->findDecoration<IRDllExportDecoration>())
- decor->removeAndDeallocate();
+ // Add generated wrapper to worklist for python bindings.
+ workList.add(hostFunc);
}
}
@@ -355,6 +814,20 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
block = nextBlock;
}
}
+
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto type = as<IRType>(globalInst))
+ {
+ if (type->findDecoration<IRPyExportDecoration>())
+ {
+ typesToExport.add(type);
+ }
+ }
+ }
+
+ for (auto type : typesToExport)
+ generateReflectionForType(type, sink);
}
// Remove all [TorchEntryPoint] functions when emitting CUDA source.
@@ -372,4 +845,164 @@ void removeTorchKernels(IRModule* module)
inst->removeAndDeallocate();
}
+void handleAutoBindNames(IRModule* module)
+{
+ // We need to rewrite extern-cpp names for functions that have an auto-bind decoration.
+ // since the name needs to be used for the host function.
+ //
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>())
+ {
+ // Find an extern decoration on the original function, and append a prefix to the name.
+ if (auto externCppHint = globalInst->findDecoration<IRExternCppDecoration>())
+ {
+ IRBuilder builder(module);
+
+ // Change the name of the original function.
+ StringBuilder nameBuilder;
+ nameBuilder << "__kernel__" << externCppHint->getName();
+ externCppHint->removeAndDeallocate();
+ builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice());
+ }
+ }
+ }
+}
+
+void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
+{
+ SLANG_UNUSED(sink);
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (!as<IRFunc>(globalInst))
+ continue;
+
+ // Look for methods marked with auto-bind and are differentiable.
+ if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>())
+ {
+ if(globalInst->findDecoration<IRForwardDifferentiableDecoration>() ||
+ globalInst->findDecoration<IRBackwardDifferentiableDecoration>())
+ {
+ // We'll generate a wrapper for this method that calls fwd_diff(fn)
+ // but an important thing to note is that we won't actually employ the usual
+ // differentiable typing rules. We'll assume none of the parameters are
+ // differentiable & throw a warning if some are. This is because, for the auto-binding
+ // scenario, we expect to only see tensor types, and their differentiation is handled using
+ // tensor _pair_ types which handle the differentiable loads/stores through custom derivatives
+ //
+ // For now, the user is expected to explicitly use the tensor pair types, so we will simply copy over
+ // the original function's signature.
+ // In the future, when we update the type system to be able to specify the corresponding pair type,
+ // we can update this logic.
+ //
+
+ // Create a new wrapper function.
+ IRBuilder builder(module);
+ auto func = cast<IRFunc>(globalInst);
+ auto wrapperFunc = builder.createFunc();
+ builder.setInsertInto(wrapperFunc);
+ builder.emitBlock();
+
+ // Clone the parameter list.
+ List<IRInst*> params;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ params.add(builder.emitParam(param->getFullType()));
+ }
+
+ wrapperFunc->setFullType(func->getFullType());
+
+ auto fwdDiffFunc = builder.emitForwardDifferentiateInst(func->getFullType(), func);
+ auto fwdDiffCall = builder.emitCallInst(
+ func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer());
+
+ builder.emitReturn(fwdDiffCall);
+
+ // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
+ if (auto kernelHint = func->findDecoration<IRCudaKernelDecoration>())
+ builder.addCudaKernelDecoration(wrapperFunc);
+
+ // Add an auto-pybind-cuda decoration to the wrapper function to further generate the
+ // host-side binding for the derivative kernel.
+ //
+ {
+ auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>();
+ StringBuilder nameBuilder;
+ nameBuilder << autoPyBindCudaHint->getFunctionName() << "_fwd_diff";
+ builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
+ }
+
+ // Build a name for the wrapper function: <original_name>_fwd_diff
+ if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
+ {
+ StringBuilder nameBuilder;
+ nameBuilder << externCppHint->getName() << "_fwd_diff";
+ builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
+ }
+
+ builder.addPublicDecoration(wrapperFunc);
+ builder.addKeepAliveDecoration(wrapperFunc);
+
+ builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc);
+ }
+
+ if (globalInst->findDecoration<IRBackwardDifferentiableDecoration>())
+ {
+ // The reasoning for the reverse-mode is the same as the forward-mode version
+ // (see above)
+ //
+
+ // Create a new wrapper function.
+ IRBuilder builder(module);
+ auto func = cast<IRFunc>(globalInst);
+ auto wrapperFunc = builder.createFunc();
+ builder.setInsertInto(wrapperFunc);
+ builder.emitBlock();
+
+ // Clone the parameter list.
+ List<IRInst*> params;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ params.add(builder.emitParam(param->getFullType()));
+ }
+
+ wrapperFunc->setFullType(func->getFullType());
+
+ auto fwdDiffFunc = builder.emitBackwardDifferentiateInst(func->getFullType(), func);
+ auto fwdDiffCall = builder.emitCallInst(
+ func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer());
+
+ builder.emitReturn(fwdDiffCall);
+
+ // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
+ if (auto kernelHint = func->findDecoration<IRCudaKernelDecoration>())
+ builder.addCudaKernelDecoration(wrapperFunc);
+
+ // Add an auto-pybind-cuda decoration to the wrapper function to further generate the
+ // host-side binding for the derivative kernel.
+ //
+ {
+ auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>();
+ StringBuilder nameBuilder;
+ nameBuilder << autoPyBindCudaHint->getFunctionName() << "_bwd_diff";
+ builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
+ }
+
+ // Build a name for the wrapper function: <original_name>_bwd_diff
+ if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
+ {
+ StringBuilder nameBuilder;
+ nameBuilder << externCppHint->getName() << "_bwd_diff";
+ builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
+ }
+
+ builder.addPublicDecoration(wrapperFunc);
+ builder.addKeepAliveDecoration(wrapperFunc);
+
+ builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc);
+ }
+ }
+ }
+}
+
}