#include "slang-ir-pytorch-cpp-binding.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-diagnostics.h"

namespace Slang
{

// Recursively flattens `type` into the leaf element types of the host-side
// (C++ binding) return tuple. The leaves are appended to `elementTypes` in
// declaration order.
//
// - The first type test contributes nothing and succeeds. (Presumably this is
//   the void type; the template argument is not visible in this copy.)
// - The next two type tests are leaf types, each appended as one element.
// - A vector whose element count is a literal contributes `count` copies of
//   its element type.
// - An array whose size is a literal contributes `arraySize` copies of its
//   element type's flattened list.
// - A struct contributes the concatenation of its fields' flattened lists.
//
// Returns false if any part of `type` cannot be flattened: a vector or array
// whose count is not a literal, or any other unsupported type. `builder` is
// not used by the current implementation.
static bool getHostReturnTypeImpl(List& elementTypes, IRBuilder& builder, IRType* type)
{
    bool isValid = true;
    if (as(type))
        return true;
    if (as(type))
        elementTypes.add(type);
    else if (as(type))
        elementTypes.add(type);
    else if (auto vectorType = as(type))
    {
        auto count = as(vectorType->getElementCount());
        if (!count)
        {
            return false;
        }
        for (IRIntegerValue i = 0; i < count->getValue(); i++)
        {
            // Each vector lane is exactly one tuple element. Append the single
            // element type; it is not a range.
            elementTypes.add(vectorType->getElementType());
        }
    }
    else if (auto arrayType = as(type))
    {
        auto arraySize = as(arrayType->getElementCount());
        if (!arraySize)
        {
            return false;
        }
        // Flatten the element type once, then repeat that list per array element.
        List subElementTypes;
        isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType());
        for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
        {
            elementTypes.addRange(subElementTypes);
        }
    }
    else if (auto structType = as(type))
    {
        for (auto field : structType->getFields())
        {
            isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType());
        }
    }
    else
    {
        return false;
    }
    return isValid;
}

// Returns the target tuple type whose elements are the flattened leaves of
// `type`. Returns nullptr if getHostReturnTypeImpl rejects any part of `type`.
static IRType* getHostReturnType(IRBuilder& builder, IRType* type)
{
    List types;
    bool isValid = getHostReturnTypeImpl(types, builder, type);
    if (isValid)
        return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer());
    return nullptr;
}

// Emits instructions at the builder's current insertion point that decompose
// `val` into its leaf values. The leaves are appended to `result` in the same
// order getHostReturnTypeImpl produces leaf types.
//
// Non-literal vector or array sizes and unsupported types yield nothing. Such
// types are already rejected by getHostReturnType before this is reached.
static void flattenToTupleImpl(List& result, IRBuilder& builder, IRInst* val)
{
    auto type = val->getDataType();
    if (as(type))
        return;
    if (as(type))
        result.add(val);
    else if (as(type))
        result.add(val);
    else if (auto vectorType = as(type))
    {
        auto count = as(vectorType->getElementCount());
        if (!count)
        {
            return;
        }
        for (IRIntegerValue i = 0; i < count->getValue(); i++)
        {
            // Extract lane `i` from the vector *value* `val`, using the same
            // form as the array branch below. The base must be the value, not
            // the element type.
            result.add(builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)));
        }
    }
    else if (auto arrayType = as(type))
    {
        auto arraySize = as(arrayType->getElementCount());
        if (!arraySize)
        {
            return;
        }
        for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
        {
            auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
            flattenToTupleImpl(result, builder, elementVal);
        }
    }
    else if (auto structType = as(type))
    {
        for (auto field : structType->getFields())
        {
            auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey());
            flattenToTupleImpl(result, builder, elementVal);
        }
    }
}

// Flattens `val` and packs its leaf values into a target tuple of type `type`.
// `type` is expected to be the result of getHostReturnType for `val`'s type,
// so that element counts and order line up.
static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val)
{
    List vals;
    flattenToTupleImpl(vals, builder, val);
    return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer());
}

// Rewrites `func` in place into its host-side C++ binding:
// - Its result type is replaced by the flattened target tuple type. If the
//   result type cannot be flattened, an invalidTorchKernelReturnType
//   diagnostic is emitted and the function is left unchanged.
// - Parameter types are carried over unchanged.
// - A variable of the kIROp_TorchKernelMemoryAllocatorType type is created
//   before the first ordinary instruction of the first block.
// - Each `kernelDispatch` instruction is replaced by a `void*` array holding
//   the addresses of local copies of its arguments, followed by an
//   emitCudaKernelLaunch on the stream returned by emitGetTorchCudaStream.
// - Each `getView` instruction is replaced by emitMakeTensorView, which uses
//   the allocator variable.
// - The operand of every `ret` instruction is replaced by the flattened host
//   return tuple.
static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
{
    IRBuilder builder(func);
    builder.setInsertBefore(func);

    auto hostReturnType = getHostReturnType(builder, func->getResultType());
    if (!hostReturnType)
    {
        sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType());
        return;
    }

    List hostParamTypes;
    auto funcType = as(func->getDataType());
    for (UInt i = 0; i < funcType->getParamCount(); i++)
    {
        hostParamTypes.add(funcType->getParamType(i));
    }
    auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType);
    func->setFullType(bindingFuncType);

    builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
    auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType));

    List instsToRemove;
    for (auto block : func->getBlocks())
    {
        for (auto inst : block->getChildren())
        {
            if (auto kernelDispatch = as(inst))
            {
                builder.setInsertBefore(kernelDispatch);
                auto kernelArgCount = kernelDispatch->getArgCount();

                // Array of `kernelArgCount` void pointers. Slot i receives the
                // address of a local variable holding argument i.
                auto argArrayType = builder.getArrayType(
                    builder.getPtrType(builder.getVoidType()),
                    builder.getIntValue(builder.getIntType(), kernelArgCount));
                auto argArrayVar = builder.emitVar(argArrayType);
                for (UInt i = 0; i < kernelArgCount; i++)
                {
                    auto arg = kernelDispatch->getArg(i);
                    auto argVar = builder.emitVar(arg->getFullType());
                    builder.emitStore(argVar, arg);
                    auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i));
                    builder.emitStore(addr, argVar);
                }

                // The launch receives a pointer to the array's first slot.
                auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0));
                builder.emitCudaKernelLaunch(
                    kernelDispatch->getBaseFn(),
                    kernelDispatch->getDispatchSize(),
                    kernelDispatch->getThreadGroupSize(),
                    argArrayPtr,
                    builder.emitGetTorchCudaStream());
                instsToRemove.add(inst);
            }
            else if (auto getView = as(inst))
            {
                builder.setInsertBefore(getView);
                auto makeView = builder.emitMakeTensorView(getView->getFullType(), allocator, inst->getOperand(0));
                getView->replaceUsesWith(makeView);
                instsToRemove.add(getView);
            }
            else if (auto ret = as(inst))
            {
                builder.setInsertBefore(ret);
                auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal());
                ret->setOperand(0, retVal);
            }
        }
    }

    // Removal is deferred so the child lists are not modified while the loops
    // above iterate them.
    for (auto inst : instsToRemove)
        inst->removeAndDeallocate();
}

// Pass entry point. It sorts the module's functions by decoration:
// - Functions with the first checked decoration get a host C++ binding
//   generated. (Presumably this is the torch entry-point decoration; the
//   template argument is not visible here.)
// - Functions with the second checked decoration have all of their blocks
//   removed, leaving body-less functions. (Presumably this is the CUDA kernel
//   decoration.)
// - Every other function has the four checked export-related decorations
//   removed.
void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
{
    List workList;
    List cudaKernels;
    for (auto globalInst : module->getGlobalInsts())
    {
        auto func = as(globalInst);
        if (!func)
            continue;
        if (func->findDecoration())
        {
            workList.add(func);
        }
        else if (func->findDecoration())
        {
            cudaKernels.add(func);
        }
        else
        {
            // Remove all other export decorations if this is not a cuda host func.
            if (auto decor = func->findDecoration())
                decor->removeAndDeallocate();
            if (auto decor = func->findDecoration())
                decor->removeAndDeallocate();
            if (auto decor = func->findDecoration())
                decor->removeAndDeallocate();
            if (auto decor = func->findDecoration())
                decor->removeAndDeallocate();
        }
    }

    for (auto func : workList)
        generateCppBindingForFunc(func, sink);

    for (auto func : cudaKernels)
    {
        // Save the next block before removing the current one.
        for (auto block = func->getFirstBlock(); block;)
        {
            auto nextBlock = block->getNextBlock();
            block->removeAndDeallocate();
            block = nextBlock;
        }
    }
}

// Remove all [TorchEntryPoint] functions when emitting CUDA source.
// Removes every global function carrying the checked decoration from `module`.
//
// Matches are collected first and removed afterwards. Calling
// removeAndDeallocate() on the instruction currently visited by the
// range-for over module->getGlobalInsts() would mutate the list being
// iterated, and the loop could then fail to reach the remaining siblings.
// This matches the deferred-removal pattern used with `instsToRemove` when
// generating the binding.
void removeTorchKernels(IRModule* module)
{
    List toRemove;
    for (auto globalInst : module->getGlobalInsts())
    {
        if (!as(globalInst))
            continue;
        if (globalInst->findDecoration())
            toRemove.add(globalInst);
    }
    for (auto inst : toRemove)
        inst->removeAndDeallocate();
}

} // namespace Slang