diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-pytorch-cpp-binding.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 601 |
1 files changed, 360 insertions, 241 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 2c91dc3b6..fa6361434 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -1,16 +1,15 @@ #include "slang-ir-pytorch-cpp-binding.h" -#include "slang-ir.h" -#include "slang-ir-insts.h" + #include "slang-diagnostics.h" #include "slang-ir-autodiff.h" +#include "slang-ir-insts.h" #include "slang-ir-lower-cuda-builtin-types.h" +#include "slang-ir.h" namespace Slang { // Convert a type to a target tuple type. -static IRType* translateToTupleType( - IRBuilder& builder, - IRType* type) +static IRType* translateToTupleType(IRBuilder& builder, IRType* type) { if (as<IRVoidType>(type)) return type; @@ -31,7 +30,8 @@ static IRType* translateToTupleType( { elementTypes.addRange(matrixType->getElementType()); } - auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + auto elementTupleType = + builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); List<IRType*> rowTypes; for (IRIntegerValue i = 0; i < colCount->getValue(); i++) { @@ -66,7 +66,9 @@ static IRType* translateToTupleType( { subElementTypes.addRange(subElementType); } - return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer()); + return builder.getTargetTupleType( + (UInt)subElementTypes.getCount(), + subElementTypes.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { @@ -118,19 +120,24 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) List<IRType*> colTypes; for (IRIntegerValue j = 0; j < colCount->getValue(); j++) { - auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto elementVal = + builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); auto tupleElement = makeTargetTuple(builder, elementVal); if (!tupleElement) return nullptr; colElements.add(tupleElement); colTypes.add(tupleElement->getFullType()); } - auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer()); + auto rowType = + builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer()); rowTypes.add(rowType); - rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer())); + rowElements.add(builder.emitMakeTargetTuple( + rowType, + (UInt)colElements.getCount(), + colElements.getBuffer())); } return builder.emitMakeTargetTuple( - builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()), + builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()), (UInt)rowElements.getCount(), rowElements.getBuffer()); } @@ -145,15 +152,20 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < count->getValue(); i++) { - auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto elementVal = + builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); auto tupleElement = makeTargetTuple(builder, elementVal); if (!tupleElement) return nullptr; resultElements.add(tupleElement); elementTypes.add(tupleElement->getFullType()); } - auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); - return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + auto resultType = + builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple( + resultType, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (auto arrayType = as<IRArrayType>(type)) { @@ -166,15 +178,20 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { - auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto elementVal = + builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); auto tupleElement = makeTargetTuple(builder, elementVal); if (!tupleElement) return nullptr; resultElements.add(tupleElement); elementTypes.add(tupleElement->getFullType()); } - auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); - return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + auto resultType = + builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple( + resultType, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { @@ -189,8 +206,12 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) resultElements.add(tupleElement); elementTypes.add(tupleElement->getFullType()); } - auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); - return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + auto resultType = + builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple( + resultType, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (as<IRTargetTupleType>(type)) { @@ -218,17 +239,30 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst SLANG_ASSERT(rowCount && colCount); List<IRInst*> resultElements; - auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List<IRType*>().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer()); + auto rowType = builder.getTargetTupleType( + (UInt)colCount->getValue(), + List<IRType*>() + .makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()) + .getBuffer()); for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) { - auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i)); + auto rowElement = builder.emitTargetTupleGetElement( + rowType, + val, + builder.getIntValue(builder.getIntType(), i)); for (IRIntegerValue j = 0; j < colCount->getValue(); j++) { - auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j)); + auto element = builder.emitTargetTupleGetElement( + matrixType->getElementType(), + rowElement, + builder.getIntValue(builder.getIntType(), j)); resultElements.add(element); } } - return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + return builder.emitMakeMatrix( + type, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (auto vectorType = as<IRVectorType>(type)) { @@ -241,13 +275,19 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst auto elementType = vectorType->getElementType(); for (IRIntegerValue i = 0; i < count->getValue(); i++) { - auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement( + elementType, + val, + builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; resultElements.add(convertedElement); } - return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + return builder.emitMakeVector( + type, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (auto arrayType = as<IRArrayType>(type)) { @@ -261,7 +301,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst auto tupleElementType = translateToTupleType(builder, elementType); for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { - auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement( + tupleElementType, + val, + builder.getIntValue(builder.getIntType(), i)); // Make a name hint: <valname>_<i> if (auto nameHint = val->findDecoration<IRNameHintDecoration>()) @@ -270,13 +313,16 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst newName << nameHint->getName() << "_" << i; builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice()); } - + auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; resultElements.add(convertedElement); } - return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + return builder.emitMakeArray( + type, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { @@ -285,10 +331,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst for (auto field : structType->getFields()) { auto tupleElement = builder.emitTargetTupleGetElement( - translateToTupleType(builder, field->getFieldType()), + translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i)); - + // Make a name hint: <valname>_<fieldname> if (auto nameHint = val->findDecoration<IRNameHintDecoration>()) { @@ -300,13 +346,17 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst } } - auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); + auto convertedElement = + makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); if (!convertedElement) return nullptr; resultElements.add(convertedElement); i++; } - return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + return builder.emitMakeStruct( + type, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } else if (as<IRTargetTupleType>(type)) { @@ -326,7 +376,10 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) auto hostReturnType = translateToTupleType(builder, func->getResultType()); if (!hostReturnType) { - sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType()); + sink->diagnose( + func->sourceLoc, + Diagnostics::invalidTorchKernelReturnType, + func->getResultType()); return; } List<IRType*> hostParamTypes; @@ -385,7 +438,8 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) builder.setInsertBefore(kernelDispatch); List<IRInst*> kernelArgs; auto kernelArgCount = kernelDispatch->getArgCount(); - auto argArrayType = builder.getArrayType(builder.getPtrType(builder.getVoidType()), + auto argArrayType = builder.getArrayType( + builder.getPtrType(builder.getVoidType()), builder.getIntValue(builder.getIntType(), kernelArgCount)); auto argArrayVar = builder.emitVar(argArrayType); for (UInt i = 0; i < kernelArgCount; i++) @@ -393,10 +447,14 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) auto arg = kernelDispatch->getArg(i); auto argVar = builder.emitVar(arg->getFullType()); builder.emitStore(argVar, arg); - auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i)); + auto addr = builder.emitElementAddress( + argArrayVar, + builder.getIntValue(builder.getIntType(), i)); builder.emitStore(addr, argVar); } - auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0)); + auto argArrayPtr = builder.emitElementAddress( + argArrayVar, + builder.getIntValue(builder.getIntType(), 0)); builder.emitCudaKernelLaunch( kernelDispatch->getBaseFn(), kernelDispatch->getDispatchSize(), @@ -408,7 +466,8 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) else if (auto getView = as<IRTorchTensorGetView>(inst)) { builder.setInsertBefore(getView); - auto makeView = builder.emitMakeTensorView(getView->getFullType(), inst->getOperand(0)); + auto makeView = + builder.emitMakeTensorView(getView->getFullType(), inst->getOperand(0)); getView->replaceUsesWith(makeView); instsToRemove.add(getView); } @@ -425,7 +484,11 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) inst->removeAndDeallocate(); } -IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) +IRType* translateToHostType( + IRBuilder* builder, + IRType* type, + IRInst* func, + DiagnosticSink* sink = nullptr) { if (as<IRBasicType>(type) || as<IRVectorType>(type) || as<IRMatrixType>(type)) return type; @@ -435,37 +498,37 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag case kIROp_TensorViewType: return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); case kIROp_StructType: - { - // Create a new struct type with translated fields. - List<IRType*> fieldTypes; - List<IRNameHintDecoration*> fieldNames; - for (auto field : as<IRStructType>(type)->getFields()) { - fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink)); - fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>()); - } - auto hostStructType = builder->createStructType(); + // Create a new struct type with translated fields. + List<IRType*> fieldTypes; + List<IRNameHintDecoration*> fieldNames; + for (auto field : as<IRStructType>(type)->getFields()) + { + fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink)); + fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>()); + } + auto hostStructType = builder->createStructType(); - // Add fields to the struct. - for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++) - { - auto structKey = builder->createStructKey(); - if (fieldNames[i]) - builder->addNameHintDecoration(structKey, fieldNames[i]->getName()); - builder->createStructField(hostStructType, structKey, fieldTypes[i]); - } + // Add fields to the struct. + for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++) + { + auto structKey = builder->createStructKey(); + if (fieldNames[i]) + builder->addNameHintDecoration(structKey, fieldNames[i]->getName()); + builder->createStructField(hostStructType, structKey, fieldTypes[i]); + } - return hostStructType; - } + return hostStructType; + } case kIROp_ArrayType: - { - auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink); - if (!elementType) - return nullptr; - return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount()); - } - default: - break; + { + auto elementType = + translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink); + if (!elementType) + return nullptr; + return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount()); + } + default: break; } if (sink) @@ -476,10 +539,11 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag // Propagates name hints through field extracts. IRInst* propagateNameHint(IRBuilder* builder, IRFieldExtract* inst) { - // If the field has a name hint, propagate it to the inst by appending the field name to the inst name - // (which must be fetched from the inst's name hint decoration). + // If the field has a name hint, propagate it to the inst by appending the field name to the + // inst name (which must be fetched from the inst's name hint decoration). // - // This is useful for propagating the name hint from a struct field to the inst that extracts the field. + // This is useful for propagating the name hint from a struct field to the inst that extracts + // the field. if (auto nameHint = inst->getField()->findDecoration<IRNameHintDecoration>()) { if (auto instNameHint = inst->getBase()->findDecoration<IRNameHintDecoration>()) @@ -519,69 +583,77 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp switch (cudaType->getOp()) { - case kIROp_TensorViewType: - return builder->emitMakeTensorView(cudaType, inst); + case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); case kIROp_StructType: - { - auto cudaStructType = cast<IRStructType>(cudaType); - auto hostStructType = cast<IRStructType>(hostType); + { + auto cudaStructType = cast<IRStructType>(cudaType); + auto hostStructType = cast<IRStructType>(hostType); - List<IRStructField*> cudaFields; - for (auto field : cudaStructType->getFields()) - cudaFields.add(field); + List<IRStructField*> cudaFields; + for (auto field : cudaStructType->getFields()) + cudaFields.add(field); - List<IRStructField*> hostFields; - for (auto field : hostStructType->getFields()) - hostFields.add(field); + List<IRStructField*> hostFields; + for (auto field : hostStructType->getFields()) + hostFields.add(field); - List<IRInst*> resultFields; - for (auto ii = 0; ii < cudaFields.getCount(); ii++) - { - auto cudaField = cudaFields[ii]; - auto hostField = hostFields[ii]; - auto cudaFieldType = cudaField->getFieldType(); - auto hostFieldType = hostField->getFieldType(); - auto castedField = castHostToCUDAType( - builder, - hostFieldType, - cudaFieldType, - propagateNameHint( + List<IRInst*> resultFields; + for (auto ii = 0; ii < cudaFields.getCount(); ii++) + { + auto cudaField = cudaFields[ii]; + auto hostField = hostFields[ii]; + auto cudaFieldType = cudaField->getFieldType(); + auto hostFieldType = hostField->getFieldType(); + auto castedField = castHostToCUDAType( builder, - cast<IRFieldExtract>(builder->emitFieldExtract(hostFieldType, inst, hostField->getKey())))); + hostFieldType, + cudaFieldType, + propagateNameHint( + builder, + cast<IRFieldExtract>( + builder->emitFieldExtract(hostFieldType, inst, hostField->getKey())))); + + SLANG_RELEASE_ASSERT(castedField); + resultFields.add(castedField); + } - SLANG_RELEASE_ASSERT(castedField); - resultFields.add(castedField); + return builder->emitMakeStruct( + cudaType, + (UInt)resultFields.getCount(), + resultFields.getBuffer()); } - - return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer()); - } case kIROp_ArrayType: - { - auto cudaArrayType = cast<IRArrayType>(cudaType); - auto hostArrayType = cast<IRArrayType>(hostType); - - List<IRInst*> resultElements; - for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++) { - auto cudaElementType = cudaArrayType->getElementType(); - auto hostElementType = hostArrayType->getElementType(); - auto castedElement = castHostToCUDAType( - builder, - hostElementType, - cudaElementType, - propagateNameHint( - builder, - cast<IRGetElement>(builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i))))); - - SLANG_RELEASE_ASSERT(castedElement); - resultElements.add(castedElement); + auto cudaArrayType = cast<IRArrayType>(cudaType); + auto hostArrayType = cast<IRArrayType>(hostType); + + List<IRInst*> resultElements; + for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); + i++) + { + auto cudaElementType = cudaArrayType->getElementType(); + auto hostElementType = hostArrayType->getElementType(); + auto castedElement = castHostToCUDAType( + builder, + hostElementType, + cudaElementType, + propagateNameHint( + builder, + cast<IRGetElement>(builder->emitElementExtract( + inst, + builder->getIntValue(builder->getIntType(), i))))); + + SLANG_RELEASE_ASSERT(castedElement); + resultElements.add(castedElement); + } + + return builder->emitMakeArray( + cudaType, + (UInt)resultElements.getCount(), + resultElements.getBuffer()); } - return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer()); - } - - default: - break; + default: break; } // If translateToHostType worked correctly, there should be no unhandled cases here. @@ -610,7 +682,8 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host { if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) { - paramNames.add(builder->emitGetNativeString(builder->getStringValue(nameHint->getName()))); + paramNames.add( + builder->emitGetNativeString(builder->getStringValue(nameHint->getName()))); } else { @@ -618,7 +691,8 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host argNameBuilder << "param"; argNameBuilder << paramCount; - paramNames.add(builder->emitGetNativeString(builder->getStringValue(argNameBuilder.getUnownedSlice()))); + paramNames.add(builder->emitGetNativeString( + builder->getStringValue(argNameBuilder.getUnownedSlice()))); } paramCount++; } @@ -628,41 +702,48 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host // Check for py-export decoration. if (auto pyExportHint = param->getDataType()->findDecoration<IRPyExportDecoration>()) { - paramTypeNames.add( - builder->emitGetNativeString( - builder->getStringValue( - pyExportHint->getExportName()))); + paramTypeNames.add(builder->emitGetNativeString( + builder->getStringValue(pyExportHint->getExportName()))); } else { paramTypeNames.add( - builder->emitGetNativeString( - builder->getStringValue( - UnownedStringSlice("")))); + builder->emitGetNativeString(builder->getStringValue(UnownedStringSlice("")))); } } // Create a target-tuple-type for the names auto paramNamesTupleType = builder->getTargetTupleType( (UInt)paramNames.getCount(), - List<IRType*>().makeRepeated(builder->getNativeStringType(), paramNames.getCount()).getBuffer()); - auto paramNamesTuple = builder->emitMakeTargetTuple(paramNamesTupleType, paramNames.getCount(), paramNames.getBuffer()); + List<IRType*>() + .makeRepeated(builder->getNativeStringType(), paramNames.getCount()) + .getBuffer()); + auto paramNamesTuple = builder->emitMakeTargetTuple( + paramNamesTupleType, + paramNames.getCount(), + paramNames.getBuffer()); // Create a target-tuple-type for the type names auto paramTypeNamesTupleType = builder->getTargetTupleType( (UInt)paramTypeNames.getCount(), - List<IRType*>().makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()).getBuffer()); - auto paramTypeNamesTuple = builder->emitMakeTargetTuple(paramTypeNamesTupleType, paramTypeNames.getCount(), paramTypeNames.getBuffer()); + List<IRType*>() + .makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()) + .getBuffer()); + auto paramTypeNamesTuple = builder->emitMakeTargetTuple( + paramTypeNamesTupleType, + paramTypeNames.getCount(), + paramTypeNames.getBuffer()); // Find the fwd-diff function name (blank string indicates no fwd-diff) IRInst* fwdDiffName = builder->getStringValue(UnownedStringSlice("")); if (auto fwdDiffHint = kernelFunc->findDecoration<IRCudaKernelForwardDerivativeDecoration>()) { auto fwdDiffFunc = fwdDiffHint->getForwardDerivativeFunc(); - + if (auto fwdDiffFuncExternHint = fwdDiffFunc->findDecoration<IRExternCppDecoration>()) { - fwdDiffName = builder->emitGetNativeString(builder->getStringValue(fwdDiffFuncExternHint->getName())); + fwdDiffName = builder->emitGetNativeString( + builder->getStringValue(fwdDiffFuncExternHint->getName())); } } @@ -671,20 +752,23 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host if (auto bwdDiffHint = kernelFunc->findDecoration<IRCudaKernelBackwardDerivativeDecoration>()) { auto bwdDiffFunc = bwdDiffHint->getBackwardDerivativeFunc(); - + if (auto bwdDiffFuncExternHint = bwdDiffFunc->findDecoration<IRExternCppDecoration>()) { - bwdDiffName = builder->emitGetNativeString(builder->getStringValue(bwdDiffFuncExternHint->getName())); + bwdDiffName = builder->emitGetNativeString( + builder->getStringValue(bwdDiffFuncExternHint->getName())); } } - + auto stringType = builder->getNativeStringType(); auto returnTupleType = builder->getTargetTupleType( 4, - List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType).getBuffer()); + List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType) + .getBuffer()); // Create a target-tuple-type for the names - auto returnTupleArgs = List<IRInst*>( paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName ); + auto returnTupleArgs = + List<IRInst*>(paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName); auto returnTuple = builder->emitMakeTargetTuple( returnTupleType, returnTupleArgs.getCount(), @@ -705,17 +789,21 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host builder->addKeepAliveDecoration(reflectionFunc); } -IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr) +IRInst* generateHostParamForCUDAParam( + IRBuilder* builder, + IRParam* param, + DiagnosticSink* sink, + IRType** outType = nullptr) { auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink); if (outType) *outType = type; - + if (!type || sink->getErrorCount() > 0) { return nullptr; } - + auto hostParam = builder->emitParam(type); // Add a namehint to the param @@ -723,11 +811,11 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno { builder->addNameHintDecoration(hostParam, nameHint->getName()); } - + // Then cast the param to the appropriate type. if (auto castedParam = castHostToCUDAType(builder, type, param->getDataType(), hostParam)) return castedParam; - + return nullptr; } @@ -736,7 +824,7 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink) // If it's a basic type, we're done. if (as<IRBasicType>(type) || as<IRVoidType>(type)) return; - + // If it's a struct type, mark for py-export. if (auto structType = as<IRStructType>(type)) { @@ -744,7 +832,7 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink) // If it already has a py-export decoration, we're done. if (!structType->findDecoration<IRPyExportDecoration>()) - { + { // Look for a name hint. UnownedStringSlice nameHint; if (auto nameHintDecoration = structType->findDecoration<IRNameHintDecoration>()) @@ -804,7 +892,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) // Emit a function that returns a py::list. // The list will contain the names of all the fields of the type. // - + if (!type->findDecoration<IRPyExportDecoration>()) return; @@ -813,64 +901,81 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) auto reflFunc = builder.createFunc(); builder.setInsertInto(reflFunc); builder.emitBlock(); - + List<IRInst*> fieldNames; List<IRInst*> fieldTypeNames; switch (type->getOp()) { case kIROp_StructType: - { - for (auto field : as<IRStructType>(type)->getFields()) { - auto structKey = field->getKey(); - // Look for a name hint. - if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>()) - fieldNames.add(builder.emitGetNativeString(builder.getStringValue(nameHintDecoration->getName()))); - else - fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); - - auto fieldType = field->getFieldType(); - auto exportName = tryGetExportTypeName(&builder, fieldType); - - if (exportName.getLength() > 0) - fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice()))); - else - fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); + for (auto field : as<IRStructType>(type)->getFields()) + { + auto structKey = field->getKey(); + // Look for a name hint. + if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>()) + fieldNames.add(builder.emitGetNativeString( + builder.getStringValue(nameHintDecoration->getName()))); + else + fieldNames.add(builder.emitGetNativeString( + builder.getStringValue(UnownedStringSlice("")))); + + auto fieldType = field->getFieldType(); + auto exportName = tryGetExportTypeName(&builder, fieldType); + + if (exportName.getLength() > 0) + fieldTypeNames.add(builder.emitGetNativeString( + builder.getStringValue(exportName.getUnownedSlice()))); + else + fieldTypeNames.add(builder.emitGetNativeString( + builder.getStringValue(UnownedStringSlice("")))); + } + break; } - break; - } case kIROp_ArrayType: - { - auto elementType = as<IRArrayType>(type)->getElementType(); - fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type")))); - fieldTypeNames.add( - builder.emitGetNativeString( - builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice()))); - - auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount()); - fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size")))); - - StringBuilder elementCountStr; - elementCountStr << elementCount->getValue(); - fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice()))); - break; - } - default: - break; + { + auto elementType = as<IRArrayType>(type)->getElementType(); + fieldNames.add( + builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type")))); + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue( + tryGetExportTypeName(&builder, elementType).getUnownedSlice()))); + + auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount()); + fieldNames.add( + builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size")))); + + StringBuilder elementCountStr; + elementCountStr << elementCount->getValue(); + fieldTypeNames.add(builder.emitGetNativeString( + builder.getStringValue(elementCountStr.getUnownedSlice()))); + break; + } + default: break; } auto _nameListTupleType = builder.getTargetTupleType( (UInt)fieldNames.getCount(), - List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldNames.getCount()).getBuffer()); - auto nameListTuple = builder.emitMakeTargetTuple(_nameListTupleType, (UInt)fieldNames.getCount(), fieldNames.getBuffer()); + List<IRType*>() + .makeRepeated(builder.getNativeStringType(), fieldNames.getCount()) + .getBuffer()); + auto nameListTuple = builder.emitMakeTargetTuple( + _nameListTupleType, + (UInt)fieldNames.getCount(), + fieldNames.getBuffer()); auto _typeNameListTupleType = builder.getTargetTupleType( (UInt)fieldTypeNames.getCount(), - List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()).getBuffer()); - auto typeNameListTuple = builder.emitMakeTargetTuple(_typeNameListTupleType, (UInt)fieldTypeNames.getCount(), fieldTypeNames.getBuffer()); + List<IRType*>() + .makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()) + .getBuffer()); + auto typeNameListTuple = builder.emitMakeTargetTuple( + _typeNameListTupleType, + (UInt)fieldTypeNames.getCount(), + fieldTypeNames.getBuffer()); - auto _nameAndTypeTupleType = builder.getTargetTupleType(2, List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer()); + auto _nameAndTypeTupleType = builder.getTargetTupleType( + 2, + List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer()); auto nameAndTypeTuple = builder.emitMakeTargetTuple( _nameAndTypeTupleType, 2, @@ -884,7 +989,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) // Set function name. StringBuilder reflFuncExportName; reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice(); - + builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); builder.addKeepAliveDecoration(reflFunc); @@ -895,25 +1000,25 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) // Check that the function has an auto-bind decoration if (!func->findDecoration<IRAutoPyBindCudaDecoration>()) return nullptr; - - // We will create a CudaHost function that will call func. + + // We will create a CudaHost function that will call func. // But before that, we need to determine the type of CudaHost. - // + // // To determine the type, first we will append two uint3 parameters to the function. // with the names "__blockSize" and "__gridSize", these will serve as input block and // grid size parameters for the launch. - // + // // Then, we will go over the parameters of func, and find a host-mapping for each type // by calling mapTypeToCudaHostType(IRType*), which turns structs into tuples, and // IRTensorViewType to IRTorchTensorType. - // - // Finally, we will create a CudaHost function and transfer the name of func over to - // the generated method. - // + // + // Finally, we will create a CudaHost function and transfer the name of func over to + // the generated method. + // // The function body will first perform any conversion logic needed to convert the // parameters from the CudaHost types to the types of func, and then use dispatch_kernel // to dispatch func with the given block and grid size. - // + // // Create new function. IRBuilder builder(func->getModule()); @@ -926,7 +1031,7 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) // Add the two uint3 parameters auto uint3Type = builder.getVectorType(builder.getUIntType(), 3); - + auto blockSizeParam = builder.emitParam(uint3Type); hostParamTypes.add(uint3Type); builder.addNameHintDecoration(blockSizeParam, UnownedStringSlice("__blockSize")); @@ -939,7 +1044,7 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) for (auto param : func->getFirstBlock()->getParams()) { IRType* hostParamType; - mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType)); + mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType)); hostParamTypes.add(hostParamType); markTypeForPyExport(param->getDataType(), sink); // Should we be marking the host type? } @@ -952,15 +1057,15 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) gridSizeParam, mappedParams.getCount(), mappedParams.getBuffer()); - + builder.emitReturn(); IRFuncType* hostFuncType = builder.getFuncType(hostParamTypes, builder.getVoidType()); hostFunc->setFullType(hostFuncType); - - // Add a torch entry point decoration to the host function to mark + + // Add a torch entry point decoration to the host function to mark // for further processing. - // + // if (auto pybindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>()) { // Mark for further processing of torch-specific insts. @@ -1012,7 +1117,8 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) List<LoweredBuiltinTypeInfo> loweredParamTypes; for (auto param : params) { - LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType()); + LoweredBuiltinTypeInfo info = + lowerType(&typeLoweringEnv, &builder, param->getDataType()); loweredParamTypes.add(info); if (info.convertLoweredToOriginal != nullptr) @@ -1021,14 +1127,17 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) auto originalType = param->getDataType(); param->setFullType(info.loweredType); - // Call the conversion function to convert the lowered parameter to the original parameter. + // Call the conversion function to convert the lowered parameter to the original + // parameter. List<IRInst*> args; args.add(param); setInsertAfterOrdinaryInst(&builder, param); - auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args); + auto convertedParam = + builder.emitCallInst(originalType, info.convertLoweredToOriginal, args); - // Replace all uses of the lowered parameter with the converted parameter, except for the call instruction. + // Replace all uses of the lowered parameter with the converted parameter, except + // for the call instruction. for (auto use = param->firstUse; use;) { auto nextUse = use->nextUse; @@ -1042,21 +1151,22 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) use->set(convertedParam); use = nextUse; } - + changed = true; } } if (!changed) continue; - + fixUpFuncType(func); - // Go through any calls to this function and insert a call to converOriginalToLowered before the call. + // Go through any calls to this function and insert a call to converOriginalToLowered before + // the call. for (auto use = func->firstUse; use;) { auto nextUse = use->nextUse; - + if (as<IRCall>(use->getUser()) || as<IRDispatchKernel>(use->getUser())) { auto user = use->getUser(); @@ -1072,7 +1182,8 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) callBuilder.setInsertBefore(user); for (auto arg : argsList) { - if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr) + if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != + nullptr) { auto convertedArg = callBuilder.emitCallInst( loweredParamTypes[convertedArgs.getCount()].loweredType, @@ -1088,7 +1199,7 @@ void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) // Rebuild the call/dispatch inst. IRInst* newCall = nullptr; - + if (as<IRCall>(user)) newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs); else if (auto dispatchInst = as<IRDispatchKernel>(user)) @@ -1233,7 +1344,8 @@ void handleAutoBindNames(IRModule* module) void removeTorchAndCUDAEntryPoints(IRModule* module) { - // Go through global insts, find cuda & torch related entry points and remove the keep-alive decoration. + // Go through global insts, find cuda & torch related entry points and remove the keep-alive + // decoration. IRBuilder builder(module); for (auto globalInst : module->getGlobalInsts()) { @@ -1259,25 +1371,26 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) { if (!as<IRFunc>(globalInst)) continue; - + // Look for methods marked with auto-bind and are differentiable. if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>()) { - if(globalInst->findDecoration<IRForwardDifferentiableDecoration>() || + if (globalInst->findDecoration<IRForwardDifferentiableDecoration>() || globalInst->findDecoration<IRBackwardDifferentiableDecoration>()) { // We'll generate a wrapper for this method that calls fwd_diff(fn) - // but an important thing to note is that we won't actually employ the usual - // differentiable typing rules. We'll assume none of the parameters are - // differentiable & throw a warning if some are. This is because, for the auto-binding - // scenario, we expect to only see tensor types, and their differentiation is handled using - // tensor _pair_ types which handle the differentiable loads/stores through custom derivatives - // - // For now, the user is expected to explicitly use the tensor pair types, so we will simply copy over - // the original function's signature. - // In the future, when we update the type system to be able to specify the corresponding pair type, - // we can update this logic. - // + // but an important thing to note is that we won't actually employ the usual + // differentiable typing rules. We'll assume none of the parameters are + // differentiable & throw a warning if some are. This is because, for the + // auto-binding scenario, we expect to only see tensor types, and their + // differentiation is handled using tensor _pair_ types which handle the + // differentiable loads/stores through custom derivatives + // + // For now, the user is expected to explicitly use the tensor pair types, so we will + // simply copy over the original function's signature. In the future, when we update + // the type system to be able to specify the corresponding pair type, we can update + // this logic. + // // Create a new wrapper function. IRBuilder builder(module); @@ -1291,7 +1404,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) for (auto param : func->getFirstBlock()->getParams()) { auto newParam = builder.emitParam(param->getFullType()); - + // Copy over the name hint. if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) builder.addNameHintDecoration(newParam, nameHint->getName()); @@ -1303,7 +1416,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) auto fwdDiffFunc = builder.emitForwardDifferentiateInst(func->getFullType(), func); auto fwdDiffCall = builder.emitCallInst( - func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer()); + func->getResultType(), + fwdDiffFunc, + params.getCount(), + params.getBuffer()); builder.emitReturn(fwdDiffCall); @@ -1314,8 +1430,8 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) builder.addExternCDecoration(wrapperFunc); } - // Add an auto-pybind-cuda decoration to the wrapper function to further generate the - // host-side binding for the derivative kernel. + // Add an auto-pybind-cuda decoration to the wrapper function to further generate + // the host-side binding for the derivative kernel. // { auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>(); @@ -1323,7 +1439,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) nameBuilder << autoPyBindCudaHint->getFunctionName() << "_fwd_diff"; builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); } - + // Build a name for the wrapper function: <original_name>_fwd_diff if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) { @@ -1341,7 +1457,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) { // The reasoning for the reverse-mode is the same as the forward-mode version // (see above) - // + // // Create a new wrapper function. IRBuilder builder(module); @@ -1355,7 +1471,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) for (auto param : func->getFirstBlock()->getParams()) { auto newParam = builder.emitParam(param->getFullType()); - + // Copy over the name hint. if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) builder.addNameHintDecoration(newParam, nameHint->getName()); @@ -1367,7 +1483,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) auto fwdDiffFunc = builder.emitBackwardDifferentiateInst(func->getFullType(), func); auto fwdDiffCall = builder.emitCallInst( - func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer()); + func->getResultType(), + fwdDiffFunc, + params.getCount(), + params.getBuffer()); builder.emitReturn(fwdDiffCall); @@ -1378,8 +1497,8 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) builder.addExternCDecoration(wrapperFunc); } - // Add an auto-pybind-cuda decoration to the wrapper function to further generate the - // host-side binding for the derivative kernel. + // Add an auto-pybind-cuda decoration to the wrapper function to further generate + // the host-side binding for the derivative kernel. // { auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>(); @@ -1387,7 +1506,7 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) nameBuilder << autoPyBindCudaHint->getFunctionName() << "_bwd_diff"; builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); } - + // Build a name for the wrapper function: <original_name>_bwd_diff if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) { @@ -1404,4 +1523,4 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) } } -} +} // namespace Slang |
