summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-pytorch-cpp-binding.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp601
1 files changed, 360 insertions, 241 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 2c91dc3b6..fa6361434 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -1,16 +1,15 @@
#include "slang-ir-pytorch-cpp-binding.h"
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
+
#include "slang-diagnostics.h"
#include "slang-ir-autodiff.h"
+#include "slang-ir-insts.h"
#include "slang-ir-lower-cuda-builtin-types.h"
+#include "slang-ir.h"
namespace Slang
{
// Convert a type to a target tuple type.
-static IRType* translateToTupleType(
- IRBuilder& builder,
- IRType* type)
+static IRType* translateToTupleType(IRBuilder& builder, IRType* type)
{
if (as<IRVoidType>(type))
return type;
@@ -31,7 +30,8 @@ static IRType* translateToTupleType(
{
elementTypes.addRange(matrixType->getElementType());
}
- auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ auto elementTupleType =
+ builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
List<IRType*> rowTypes;
for (IRIntegerValue i = 0; i < colCount->getValue(); i++)
{
@@ -66,7 +66,9 @@ static IRType* translateToTupleType(
{
subElementTypes.addRange(subElementType);
}
- return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer());
+ return builder.getTargetTupleType(
+ (UInt)subElementTypes.getCount(),
+ subElementTypes.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
@@ -118,19 +120,24 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
List<IRType*> colTypes;
for (IRIntegerValue j = 0; j < colCount->getValue(); j++)
{
- auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto elementVal =
+ builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
auto tupleElement = makeTargetTuple(builder, elementVal);
if (!tupleElement)
return nullptr;
colElements.add(tupleElement);
colTypes.add(tupleElement->getFullType());
}
- auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer());
+ auto rowType =
+ builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer());
rowTypes.add(rowType);
- rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer()));
+ rowElements.add(builder.emitMakeTargetTuple(
+ rowType,
+ (UInt)colElements.getCount(),
+ colElements.getBuffer()));
}
return builder.emitMakeTargetTuple(
- builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()),
+ builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()),
(UInt)rowElements.getCount(),
rowElements.getBuffer());
}
@@ -145,15 +152,20 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < count->getValue(); i++)
{
- auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto elementVal =
+ builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
auto tupleElement = makeTargetTuple(builder, elementVal);
if (!tupleElement)
return nullptr;
resultElements.add(tupleElement);
elementTypes.add(tupleElement->getFullType());
}
- auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
- return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ auto resultType =
+ builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(
+ resultType,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (auto arrayType = as<IRArrayType>(type))
{
@@ -166,15 +178,20 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto elementVal =
+ builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
auto tupleElement = makeTargetTuple(builder, elementVal);
if (!tupleElement)
return nullptr;
resultElements.add(tupleElement);
elementTypes.add(tupleElement->getFullType());
}
- auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
- return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ auto resultType =
+ builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(
+ resultType,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
@@ -189,8 +206,12 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
resultElements.add(tupleElement);
elementTypes.add(tupleElement->getFullType());
}
- auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
- return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ auto resultType =
+ builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(
+ resultType,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (as<IRTargetTupleType>(type))
{
@@ -218,17 +239,30 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
SLANG_ASSERT(rowCount && colCount);
List<IRInst*> resultElements;
- auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List<IRType*>().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer());
+ auto rowType = builder.getTargetTupleType(
+ (UInt)colCount->getValue(),
+ List<IRType*>()
+ .makeRepeated(matrixType->getElementType(), (Index)colCount->getValue())
+ .getBuffer());
for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
{
- auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i));
+ auto rowElement = builder.emitTargetTupleGetElement(
+ rowType,
+ val,
+ builder.getIntValue(builder.getIntType(), i));
for (IRIntegerValue j = 0; j < colCount->getValue(); j++)
{
- auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j));
+ auto element = builder.emitTargetTupleGetElement(
+ matrixType->getElementType(),
+ rowElement,
+ builder.getIntValue(builder.getIntType(), j));
resultElements.add(element);
}
}
- return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ return builder.emitMakeMatrix(
+ type,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (auto vectorType = as<IRVectorType>(type))
{
@@ -241,13 +275,19 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
auto elementType = vectorType->getElementType();
for (IRIntegerValue i = 0; i < count->getValue(); i++)
{
- auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(
+ elementType,
+ val,
+ builder.getIntValue(builder.getIntType(), i));
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
resultElements.add(convertedElement);
}
- return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ return builder.emitMakeVector(
+ type,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (auto arrayType = as<IRArrayType>(type))
{
@@ -261,7 +301,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
auto tupleElementType = translateToTupleType(builder, elementType);
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(
+ tupleElementType,
+ val,
+ builder.getIntValue(builder.getIntType(), i));
// Make a name hint: <valname>_<i>
if (auto nameHint = val->findDecoration<IRNameHintDecoration>())
@@ -270,13 +313,16 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
newName << nameHint->getName() << "_" << i;
builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice());
}
-
+
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
resultElements.add(convertedElement);
}
- return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ return builder.emitMakeArray(
+ type,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
@@ -285,10 +331,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
for (auto field : structType->getFields())
{
auto tupleElement = builder.emitTargetTupleGetElement(
- translateToTupleType(builder, field->getFieldType()),
+ translateToTupleType(builder, field->getFieldType()),
val,
builder.getIntValue(builder.getIntType(), i));
-
+
// Make a name hint: <valname>_<fieldname>
if (auto nameHint = val->findDecoration<IRNameHintDecoration>())
{
@@ -300,13 +346,17 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
}
}
- auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
+ auto convertedElement =
+ makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
if (!convertedElement)
return nullptr;
resultElements.add(convertedElement);
i++;
}
- return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ return builder.emitMakeStruct(
+ type,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
else if (as<IRTargetTupleType>(type))
{
@@ -326,7 +376,10 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
auto hostReturnType = translateToTupleType(builder, func->getResultType());
if (!hostReturnType)
{
- sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType());
+ sink->diagnose(
+ func->sourceLoc,
+ Diagnostics::invalidTorchKernelReturnType,
+ func->getResultType());
return;
}
List<IRType*> hostParamTypes;
@@ -385,7 +438,8 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
builder.setInsertBefore(kernelDispatch);
List<IRInst*> kernelArgs;
auto kernelArgCount = kernelDispatch->getArgCount();
- auto argArrayType = builder.getArrayType(builder.getPtrType(builder.getVoidType()),
+ auto argArrayType = builder.getArrayType(
+ builder.getPtrType(builder.getVoidType()),
builder.getIntValue(builder.getIntType(), kernelArgCount));
auto argArrayVar = builder.emitVar(argArrayType);
for (UInt i = 0; i < kernelArgCount; i++)
@@ -393,10 +447,14 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
auto arg = kernelDispatch->getArg(i);
auto argVar = builder.emitVar(arg->getFullType());
builder.emitStore(argVar, arg);
- auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i));
+ auto addr = builder.emitElementAddress(
+ argArrayVar,
+ builder.getIntValue(builder.getIntType(), i));
builder.emitStore(addr, argVar);
}
- auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0));
+ auto argArrayPtr = builder.emitElementAddress(
+ argArrayVar,
+ builder.getIntValue(builder.getIntType(), 0));
builder.emitCudaKernelLaunch(
kernelDispatch->getBaseFn(),
kernelDispatch->getDispatchSize(),
@@ -408,7 +466,8 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
else if (auto getView = as<IRTorchTensorGetView>(inst))
{
builder.setInsertBefore(getView);
- auto makeView = builder.emitMakeTensorView(getView->getFullType(), inst->getOperand(0));
+ auto makeView =
+ builder.emitMakeTensorView(getView->getFullType(), inst->getOperand(0));
getView->replaceUsesWith(makeView);
instsToRemove.add(getView);
}
@@ -425,7 +484,11 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
inst->removeAndDeallocate();
}
-IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr)
+IRType* translateToHostType(
+ IRBuilder* builder,
+ IRType* type,
+ IRInst* func,
+ DiagnosticSink* sink = nullptr)
{
if (as<IRBasicType>(type) || as<IRVectorType>(type) || as<IRMatrixType>(type))
return type;
@@ -435,37 +498,37 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
case kIROp_TensorViewType:
return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
case kIROp_StructType:
- {
- // Create a new struct type with translated fields.
- List<IRType*> fieldTypes;
- List<IRNameHintDecoration*> fieldNames;
- for (auto field : as<IRStructType>(type)->getFields())
{
- fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
- fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>());
- }
- auto hostStructType = builder->createStructType();
+ // Create a new struct type with translated fields.
+ List<IRType*> fieldTypes;
+ List<IRNameHintDecoration*> fieldNames;
+ for (auto field : as<IRStructType>(type)->getFields())
+ {
+ fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
+ fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>());
+ }
+ auto hostStructType = builder->createStructType();
- // Add fields to the struct.
- for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++)
- {
- auto structKey = builder->createStructKey();
- if (fieldNames[i])
- builder->addNameHintDecoration(structKey, fieldNames[i]->getName());
- builder->createStructField(hostStructType, structKey, fieldTypes[i]);
- }
+ // Add fields to the struct.
+ for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++)
+ {
+ auto structKey = builder->createStructKey();
+ if (fieldNames[i])
+ builder->addNameHintDecoration(structKey, fieldNames[i]->getName());
+ builder->createStructField(hostStructType, structKey, fieldTypes[i]);
+ }
- return hostStructType;
- }
+ return hostStructType;
+ }
case kIROp_ArrayType:
- {
- auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink);
- if (!elementType)
- return nullptr;
- return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount());
- }
- default:
- break;
+ {
+ auto elementType =
+ translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink);
+ if (!elementType)
+ return nullptr;
+ return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount());
+ }
+ default: break;
}
if (sink)
@@ -476,10 +539,11 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
// Propagates name hints through field extracts.
IRInst* propagateNameHint(IRBuilder* builder, IRFieldExtract* inst)
{
- // If the field has a name hint, propagate it to the inst by appending the field name to the inst name
- // (which must be fetched from the inst's name hint decoration).
+ // If the field has a name hint, propagate it to the inst by appending the field name to the
+ // inst name (which must be fetched from the inst's name hint decoration).
//
- // This is useful for propagating the name hint from a struct field to the inst that extracts the field.
+ // This is useful for propagating the name hint from a struct field to the inst that extracts
+ // the field.
if (auto nameHint = inst->getField()->findDecoration<IRNameHintDecoration>())
{
if (auto instNameHint = inst->getBase()->findDecoration<IRNameHintDecoration>())
@@ -519,69 +583,77 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
switch (cudaType->getOp())
{
- case kIROp_TensorViewType:
- return builder->emitMakeTensorView(cudaType, inst);
+ case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst);
case kIROp_StructType:
- {
- auto cudaStructType = cast<IRStructType>(cudaType);
- auto hostStructType = cast<IRStructType>(hostType);
+ {
+ auto cudaStructType = cast<IRStructType>(cudaType);
+ auto hostStructType = cast<IRStructType>(hostType);
- List<IRStructField*> cudaFields;
- for (auto field : cudaStructType->getFields())
- cudaFields.add(field);
+ List<IRStructField*> cudaFields;
+ for (auto field : cudaStructType->getFields())
+ cudaFields.add(field);
- List<IRStructField*> hostFields;
- for (auto field : hostStructType->getFields())
- hostFields.add(field);
+ List<IRStructField*> hostFields;
+ for (auto field : hostStructType->getFields())
+ hostFields.add(field);
- List<IRInst*> resultFields;
- for (auto ii = 0; ii < cudaFields.getCount(); ii++)
- {
- auto cudaField = cudaFields[ii];
- auto hostField = hostFields[ii];
- auto cudaFieldType = cudaField->getFieldType();
- auto hostFieldType = hostField->getFieldType();
- auto castedField = castHostToCUDAType(
- builder,
- hostFieldType,
- cudaFieldType,
- propagateNameHint(
+ List<IRInst*> resultFields;
+ for (auto ii = 0; ii < cudaFields.getCount(); ii++)
+ {
+ auto cudaField = cudaFields[ii];
+ auto hostField = hostFields[ii];
+ auto cudaFieldType = cudaField->getFieldType();
+ auto hostFieldType = hostField->getFieldType();
+ auto castedField = castHostToCUDAType(
builder,
- cast<IRFieldExtract>(builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()))));
+ hostFieldType,
+ cudaFieldType,
+ propagateNameHint(
+ builder,
+ cast<IRFieldExtract>(
+ builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()))));
+
+ SLANG_RELEASE_ASSERT(castedField);
+ resultFields.add(castedField);
+ }
- SLANG_RELEASE_ASSERT(castedField);
- resultFields.add(castedField);
+ return builder->emitMakeStruct(
+ cudaType,
+ (UInt)resultFields.getCount(),
+ resultFields.getBuffer());
}
-
- return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer());
- }
case kIROp_ArrayType:
- {
- auto cudaArrayType = cast<IRArrayType>(cudaType);
- auto hostArrayType = cast<IRArrayType>(hostType);
-
- List<IRInst*> resultElements;
- for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++)
{
- auto cudaElementType = cudaArrayType->getElementType();
- auto hostElementType = hostArrayType->getElementType();
- auto castedElement = castHostToCUDAType(
- builder,
- hostElementType,
- cudaElementType,
- propagateNameHint(
- builder,
- cast<IRGetElement>(builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)))));
-
- SLANG_RELEASE_ASSERT(castedElement);
- resultElements.add(castedElement);
+ auto cudaArrayType = cast<IRArrayType>(cudaType);
+ auto hostArrayType = cast<IRArrayType>(hostType);
+
+ List<IRInst*> resultElements;
+ for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue();
+ i++)
+ {
+ auto cudaElementType = cudaArrayType->getElementType();
+ auto hostElementType = hostArrayType->getElementType();
+ auto castedElement = castHostToCUDAType(
+ builder,
+ hostElementType,
+ cudaElementType,
+ propagateNameHint(
+ builder,
+ cast<IRGetElement>(builder->emitElementExtract(
+ inst,
+ builder->getIntValue(builder->getIntType(), i)))));
+
+ SLANG_RELEASE_ASSERT(castedElement);
+ resultElements.add(castedElement);
+ }
+
+ return builder->emitMakeArray(
+ cudaType,
+ (UInt)resultElements.getCount(),
+ resultElements.getBuffer());
}
- return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer());
- }
-
- default:
- break;
+ default: break;
}
// If translateToHostType worked correctly, there should be no unhandled cases here.
@@ -610,7 +682,8 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
{
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
{
- paramNames.add(builder->emitGetNativeString(builder->getStringValue(nameHint->getName())));
+ paramNames.add(
+ builder->emitGetNativeString(builder->getStringValue(nameHint->getName())));
}
else
{
@@ -618,7 +691,8 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
argNameBuilder << "param";
argNameBuilder << paramCount;
- paramNames.add(builder->emitGetNativeString(builder->getStringValue(argNameBuilder.getUnownedSlice())));
+ paramNames.add(builder->emitGetNativeString(
+ builder->getStringValue(argNameBuilder.getUnownedSlice())));
}
paramCount++;
}
@@ -628,41 +702,48 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
// Check for py-export decoration.
if (auto pyExportHint = param->getDataType()->findDecoration<IRPyExportDecoration>())
{
- paramTypeNames.add(
- builder->emitGetNativeString(
- builder->getStringValue(
- pyExportHint->getExportName())));
+ paramTypeNames.add(builder->emitGetNativeString(
+ builder->getStringValue(pyExportHint->getExportName())));
}
else
{
paramTypeNames.add(
- builder->emitGetNativeString(
- builder->getStringValue(
- UnownedStringSlice(""))));
+ builder->emitGetNativeString(builder->getStringValue(UnownedStringSlice(""))));
}
}
// Create a target-tuple-type for the names
auto paramNamesTupleType = builder->getTargetTupleType(
(UInt)paramNames.getCount(),
- List<IRType*>().makeRepeated(builder->getNativeStringType(), paramNames.getCount()).getBuffer());
- auto paramNamesTuple = builder->emitMakeTargetTuple(paramNamesTupleType, paramNames.getCount(), paramNames.getBuffer());
+ List<IRType*>()
+ .makeRepeated(builder->getNativeStringType(), paramNames.getCount())
+ .getBuffer());
+ auto paramNamesTuple = builder->emitMakeTargetTuple(
+ paramNamesTupleType,
+ paramNames.getCount(),
+ paramNames.getBuffer());
// Create a target-tuple-type for the type names
auto paramTypeNamesTupleType = builder->getTargetTupleType(
(UInt)paramTypeNames.getCount(),
- List<IRType*>().makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()).getBuffer());
- auto paramTypeNamesTuple = builder->emitMakeTargetTuple(paramTypeNamesTupleType, paramTypeNames.getCount(), paramTypeNames.getBuffer());
+ List<IRType*>()
+ .makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount())
+ .getBuffer());
+ auto paramTypeNamesTuple = builder->emitMakeTargetTuple(
+ paramTypeNamesTupleType,
+ paramTypeNames.getCount(),
+ paramTypeNames.getBuffer());
// Find the fwd-diff function name (blank string indicates no fwd-diff)
IRInst* fwdDiffName = builder->getStringValue(UnownedStringSlice(""));
if (auto fwdDiffHint = kernelFunc->findDecoration<IRCudaKernelForwardDerivativeDecoration>())
{
auto fwdDiffFunc = fwdDiffHint->getForwardDerivativeFunc();
-
+
if (auto fwdDiffFuncExternHint = fwdDiffFunc->findDecoration<IRExternCppDecoration>())
{
- fwdDiffName = builder->emitGetNativeString(builder->getStringValue(fwdDiffFuncExternHint->getName()));
+ fwdDiffName = builder->emitGetNativeString(
+ builder->getStringValue(fwdDiffFuncExternHint->getName()));
}
}
@@ -671,20 +752,23 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
if (auto bwdDiffHint = kernelFunc->findDecoration<IRCudaKernelBackwardDerivativeDecoration>())
{
auto bwdDiffFunc = bwdDiffHint->getBackwardDerivativeFunc();
-
+
if (auto bwdDiffFuncExternHint = bwdDiffFunc->findDecoration<IRExternCppDecoration>())
{
- bwdDiffName = builder->emitGetNativeString(builder->getStringValue(bwdDiffFuncExternHint->getName()));
+ bwdDiffName = builder->emitGetNativeString(
+ builder->getStringValue(bwdDiffFuncExternHint->getName()));
}
}
-
+
auto stringType = builder->getNativeStringType();
auto returnTupleType = builder->getTargetTupleType(
4,
- List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType).getBuffer());
+ List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType)
+ .getBuffer());
// Create a target-tuple-type for the names
- auto returnTupleArgs = List<IRInst*>( paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName );
+ auto returnTupleArgs =
+ List<IRInst*>(paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName);
auto returnTuple = builder->emitMakeTargetTuple(
returnTupleType,
returnTupleArgs.getCount(),
@@ -705,17 +789,21 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
builder->addKeepAliveDecoration(reflectionFunc);
}
-IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr)
+IRInst* generateHostParamForCUDAParam(
+ IRBuilder* builder,
+ IRParam* param,
+ DiagnosticSink* sink,
+ IRType** outType = nullptr)
{
auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink);
if (outType)
*outType = type;
-
+
if (!type || sink->getErrorCount() > 0)
{
return nullptr;
}
-
+
auto hostParam = builder->emitParam(type);
// Add a namehint to the param
@@ -723,11 +811,11 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno
{
builder->addNameHintDecoration(hostParam, nameHint->getName());
}
-
+
// Then cast the param to the appropriate type.
if (auto castedParam = castHostToCUDAType(builder, type, param->getDataType(), hostParam))
return castedParam;
-
+
return nullptr;
}
@@ -736,7 +824,7 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
// If it's a basic type, we're done.
if (as<IRBasicType>(type) || as<IRVoidType>(type))
return;
-
+
// If it's a struct type, mark for py-export.
if (auto structType = as<IRStructType>(type))
{
@@ -744,7 +832,7 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
// If it already has a py-export decoration, we're done.
if (!structType->findDecoration<IRPyExportDecoration>())
- {
+ {
// Look for a name hint.
UnownedStringSlice nameHint;
if (auto nameHintDecoration = structType->findDecoration<IRNameHintDecoration>())
@@ -804,7 +892,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// Emit a function that returns a py::list.
// The list will contain the names of all the fields of the type.
//
-
+
if (!type->findDecoration<IRPyExportDecoration>())
return;
@@ -813,64 +901,81 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
auto reflFunc = builder.createFunc();
builder.setInsertInto(reflFunc);
builder.emitBlock();
-
+
List<IRInst*> fieldNames;
List<IRInst*> fieldTypeNames;
switch (type->getOp())
{
case kIROp_StructType:
- {
- for (auto field : as<IRStructType>(type)->getFields())
{
- auto structKey = field->getKey();
- // Look for a name hint.
- if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>())
- fieldNames.add(builder.emitGetNativeString(builder.getStringValue(nameHintDecoration->getName())));
- else
- fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
-
- auto fieldType = field->getFieldType();
- auto exportName = tryGetExportTypeName(&builder, fieldType);
-
- if (exportName.getLength() > 0)
- fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice())));
- else
- fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
+ for (auto field : as<IRStructType>(type)->getFields())
+ {
+ auto structKey = field->getKey();
+ // Look for a name hint.
+ if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>())
+ fieldNames.add(builder.emitGetNativeString(
+ builder.getStringValue(nameHintDecoration->getName())));
+ else
+ fieldNames.add(builder.emitGetNativeString(
+ builder.getStringValue(UnownedStringSlice(""))));
+
+ auto fieldType = field->getFieldType();
+ auto exportName = tryGetExportTypeName(&builder, fieldType);
+
+ if (exportName.getLength() > 0)
+ fieldTypeNames.add(builder.emitGetNativeString(
+ builder.getStringValue(exportName.getUnownedSlice())));
+ else
+ fieldTypeNames.add(builder.emitGetNativeString(
+ builder.getStringValue(UnownedStringSlice(""))));
+ }
+ break;
}
- break;
- }
case kIROp_ArrayType:
- {
- auto elementType = as<IRArrayType>(type)->getElementType();
- fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type"))));
- fieldTypeNames.add(
- builder.emitGetNativeString(
- builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice())));
-
- auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount());
- fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size"))));
-
- StringBuilder elementCountStr;
- elementCountStr << elementCount->getValue();
- fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice())));
- break;
- }
- default:
- break;
+ {
+ auto elementType = as<IRArrayType>(type)->getElementType();
+ fieldNames.add(
+ builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type"))));
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(
+ tryGetExportTypeName(&builder, elementType).getUnownedSlice())));
+
+ auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount());
+ fieldNames.add(
+ builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size"))));
+
+ StringBuilder elementCountStr;
+ elementCountStr << elementCount->getValue();
+ fieldTypeNames.add(builder.emitGetNativeString(
+ builder.getStringValue(elementCountStr.getUnownedSlice())));
+ break;
+ }
+ default: break;
}
auto _nameListTupleType = builder.getTargetTupleType(
(UInt)fieldNames.getCount(),
- List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldNames.getCount()).getBuffer());
- auto nameListTuple = builder.emitMakeTargetTuple(_nameListTupleType, (UInt)fieldNames.getCount(), fieldNames.getBuffer());
+ List<IRType*>()
+ .makeRepeated(builder.getNativeStringType(), fieldNames.getCount())
+ .getBuffer());
+ auto nameListTuple = builder.emitMakeTargetTuple(
+ _nameListTupleType,
+ (UInt)fieldNames.getCount(),
+ fieldNames.getBuffer());
auto _typeNameListTupleType = builder.getTargetTupleType(
(UInt)fieldTypeNames.getCount(),
- List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()).getBuffer());
- auto typeNameListTuple = builder.emitMakeTargetTuple(_typeNameListTupleType, (UInt)fieldTypeNames.getCount(), fieldTypeNames.getBuffer());
+ List<IRType*>()
+ .makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount())
+ .getBuffer());
+ auto typeNameListTuple = builder.emitMakeTargetTuple(
+ _typeNameListTupleType,
+ (UInt)fieldTypeNames.getCount(),
+ fieldTypeNames.getBuffer());
- auto _nameAndTypeTupleType = builder.getTargetTupleType(2, List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer());
+ auto _nameAndTypeTupleType = builder.getTargetTupleType(
+ 2,
+ List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer());
auto nameAndTypeTuple = builder.emitMakeTargetTuple(
_nameAndTypeTupleType,
2,
@@ -884,7 +989,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// Set function name.
StringBuilder reflFuncExportName;
reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice();
-
+
builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addKeepAliveDecoration(reflFunc);
@@ -895,25 +1000,25 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
// Check that the function has an auto-bind decoration
if (!func->findDecoration<IRAutoPyBindCudaDecoration>())
return nullptr;
-
- // We will create a CudaHost function that will call func.
+
+ // We will create a CudaHost function that will call func.
// But before that, we need to determine the type of CudaHost.
- //
+ //
// To determine the type, first we will append two uint3 parameters to the function.
// with the names "__blockSize" and "__gridSize", these will serve as input block and
// grid size parameters for the launch.
- //
+ //
// Then, we will go over the parameters of func, and find a host-mapping for each type
// by calling mapTypeToCudaHostType(IRType*), which turns structs into tuples, and
// IRTensorViewType to IRTorchTensorType.
- //
- // Finally, we will create a CudaHost function and transfer the name of func over to
- // the generated method.
- //
+ //
+ // Finally, we will create a CudaHost function and transfer the name of func over to
+ // the generated method.
+ //
// The function body will first perform any conversion logic needed to convert the
// parameters from the CudaHost types to the types of func, and then use dispatch_kernel
// to dispatch func with the given block and grid size.
- //
+ //
// Create new function.
IRBuilder builder(func->getModule());
@@ -926,7 +1031,7 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
// Add the two uint3 parameters
auto uint3Type = builder.getVectorType(builder.getUIntType(), 3);
-
+
auto blockSizeParam = builder.emitParam(uint3Type);
hostParamTypes.add(uint3Type);
builder.addNameHintDecoration(blockSizeParam, UnownedStringSlice("__blockSize"));
@@ -939,7 +1044,7 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
for (auto param : func->getFirstBlock()->getParams())
{
IRType* hostParamType;
- mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType));
+ mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType));
hostParamTypes.add(hostParamType);
markTypeForPyExport(param->getDataType(), sink); // Should we be marking the host type?
}
@@ -952,15 +1057,15 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
gridSizeParam,
mappedParams.getCount(),
mappedParams.getBuffer());
-
+
builder.emitReturn();
IRFuncType* hostFuncType = builder.getFuncType(hostParamTypes, builder.getVoidType());
hostFunc->setFullType(hostFuncType);
-
- // Add a torch entry point decoration to the host function to mark
+
+ // Add a torch entry point decoration to the host function to mark
// for further processing.
- //
+ //
if (auto pybindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>())
{
// Mark for further processing of torch-specific insts.
@@ -1012,7 +1117,8 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*)
List<LoweredBuiltinTypeInfo> loweredParamTypes;
for (auto param : params)
{
- LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType());
+ LoweredBuiltinTypeInfo info =
+ lowerType(&typeLoweringEnv, &builder, param->getDataType());
loweredParamTypes.add(info);
if (info.convertLoweredToOriginal != nullptr)
@@ -1021,14 +1127,17 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*)
auto originalType = param->getDataType();
param->setFullType(info.loweredType);
- // Call the conversion function to convert the lowered parameter to the original parameter.
+ // Call the conversion function to convert the lowered parameter to the original
+ // parameter.
List<IRInst*> args;
args.add(param);
setInsertAfterOrdinaryInst(&builder, param);
- auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args);
+ auto convertedParam =
+ builder.emitCallInst(originalType, info.convertLoweredToOriginal, args);
- // Replace all uses of the lowered parameter with the converted parameter, except for the call instruction.
+ // Replace all uses of the lowered parameter with the converted parameter, except
+ // for the call instruction.
for (auto use = param->firstUse; use;)
{
auto nextUse = use->nextUse;
@@ -1042,21 +1151,22 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*)
use->set(convertedParam);
use = nextUse;
}
-
+
changed = true;
}
}
if (!changed)
continue;
-
+
fixUpFuncType(func);
- // Go through any calls to this function and insert a call to converOriginalToLowered before the call.
+ // Go through any calls to this function and insert a call to converOriginalToLowered before
+ // the call.
for (auto use = func->firstUse; use;)
{
auto nextUse = use->nextUse;
-
+
if (as<IRCall>(use->getUser()) || as<IRDispatchKernel>(use->getUser()))
{
auto user = use->getUser();
@@ -1072,7 +1182,8 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*)
callBuilder.setInsertBefore(user);
for (auto arg : argsList)
{
- if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr)
+ if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered !=
+ nullptr)
{
auto convertedArg = callBuilder.emitCallInst(
loweredParamTypes[convertedArgs.getCount()].loweredType,
@@ -1088,7 +1199,7 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*)
// Rebuild the call/dispatch inst.
IRInst* newCall = nullptr;
-
+
if (as<IRCall>(user))
newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs);
else if (auto dispatchInst = as<IRDispatchKernel>(user))
@@ -1233,7 +1344,8 @@ void handleAutoBindNames(IRModule* module)
void removeTorchAndCUDAEntryPoints(IRModule* module)
{
- // Go through global insts, find cuda & torch related entry points and remove the keep-alive decoration.
+ // Go through global insts, find cuda & torch related entry points and remove the keep-alive
+ // decoration.
IRBuilder builder(module);
for (auto globalInst : module->getGlobalInsts())
{
@@ -1259,25 +1371,26 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
{
if (!as<IRFunc>(globalInst))
continue;
-
+
// Look for methods marked with auto-bind and are differentiable.
if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>())
{
- if(globalInst->findDecoration<IRForwardDifferentiableDecoration>() ||
+ if (globalInst->findDecoration<IRForwardDifferentiableDecoration>() ||
globalInst->findDecoration<IRBackwardDifferentiableDecoration>())
{
// We'll generate a wrapper for this method that calls fwd_diff(fn)
- // but an important thing to note is that we won't actually employ the usual
- // differentiable typing rules. We'll assume none of the parameters are
- // differentiable & throw a warning if some are. This is because, for the auto-binding
- // scenario, we expect to only see tensor types, and their differentiation is handled using
- // tensor _pair_ types which handle the differentiable loads/stores through custom derivatives
- //
- // For now, the user is expected to explicitly use the tensor pair types, so we will simply copy over
- // the original function's signature.
- // In the future, when we update the type system to be able to specify the corresponding pair type,
- // we can update this logic.
- //
+ // but an important thing to note is that we won't actually employ the usual
+ // differentiable typing rules. We'll assume none of the parameters are
+ // differentiable & throw a warning if some are. This is because, for the
+ // auto-binding scenario, we expect to only see tensor types, and their
+ // differentiation is handled using tensor _pair_ types which handle the
+ // differentiable loads/stores through custom derivatives
+ //
+ // For now, the user is expected to explicitly use the tensor pair types, so we will
+ // simply copy over the original function's signature. In the future, when we update
+ // the type system to be able to specify the corresponding pair type, we can update
+ // this logic.
+ //
// Create a new wrapper function.
IRBuilder builder(module);
@@ -1291,7 +1404,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
for (auto param : func->getFirstBlock()->getParams())
{
auto newParam = builder.emitParam(param->getFullType());
-
+
// Copy over the name hint.
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
builder.addNameHintDecoration(newParam, nameHint->getName());
@@ -1303,7 +1416,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
auto fwdDiffFunc = builder.emitForwardDifferentiateInst(func->getFullType(), func);
auto fwdDiffCall = builder.emitCallInst(
- func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer());
+ func->getResultType(),
+ fwdDiffFunc,
+ params.getCount(),
+ params.getBuffer());
builder.emitReturn(fwdDiffCall);
@@ -1314,8 +1430,8 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCDecoration(wrapperFunc);
}
- // Add an auto-pybind-cuda decoration to the wrapper function to further generate the
- // host-side binding for the derivative kernel.
+ // Add an auto-pybind-cuda decoration to the wrapper function to further generate
+ // the host-side binding for the derivative kernel.
//
{
auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>();
@@ -1323,7 +1439,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
nameBuilder << autoPyBindCudaHint->getFunctionName() << "_fwd_diff";
builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}
-
+
// Build a name for the wrapper function: <original_name>_fwd_diff
if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
{
@@ -1341,7 +1457,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
{
// The reasoning for the reverse-mode is the same as the forward-mode version
// (see above)
- //
+ //
// Create a new wrapper function.
IRBuilder builder(module);
@@ -1355,7 +1471,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
for (auto param : func->getFirstBlock()->getParams())
{
auto newParam = builder.emitParam(param->getFullType());
-
+
// Copy over the name hint.
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
builder.addNameHintDecoration(newParam, nameHint->getName());
@@ -1367,7 +1483,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
auto fwdDiffFunc = builder.emitBackwardDifferentiateInst(func->getFullType(), func);
auto fwdDiffCall = builder.emitCallInst(
- func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer());
+ func->getResultType(),
+ fwdDiffFunc,
+ params.getCount(),
+ params.getBuffer());
builder.emitReturn(fwdDiffCall);
@@ -1378,8 +1497,8 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCDecoration(wrapperFunc);
}
- // Add an auto-pybind-cuda decoration to the wrapper function to further generate the
- // host-side binding for the derivative kernel.
+ // Add an auto-pybind-cuda decoration to the wrapper function to further generate
+ // the host-side binding for the derivative kernel.
//
{
auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>();
@@ -1387,7 +1506,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
nameBuilder << autoPyBindCudaHint->getFunctionName() << "_bwd_diff";
builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}
-
+
// Build a name for the wrapper function: <original_name>_bwd_diff
if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
{
@@ -1404,4 +1523,4 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
}
}
-}
+} // namespace Slang