diff options
| author | Yong He <yonghe@outlook.com> | 2023-09-21 14:00:48 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-21 14:00:48 -0700 |
| commit | 5b2eb06816521cc0fcfe03258452560bd200002d (patch) | |
| tree | dc06cc626ff0059dded3f4245f9309b3071ae94c /source/slang/slang-ir-pytorch-cpp-binding.cpp | |
| parent | af8ce68e9fd7b6255b6e4e9e9524a285497116dc (diff) | |
Various slangpy fixes. (#3227)
* Make dynamic cast transparent through `IRAttributedType`.
* Add `[CUDAXxx]` variants of attributes.
* Support marshaling of vector types.
* Wrap cuda kernels in `extern "C"` block.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 51 |
1 file changed, 43 insertions, 8 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index c723902de..41665ddf7 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -193,7 +193,7 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst IRIntegerValue i = 0; for (auto field : structType->getFields()) { - auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); if (!convertedElement) return nullptr; @@ -315,23 +315,38 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) inst->removeAndDeallocate(); } -IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr) +IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) { - if (as<IRBasicType>(type)) + if (as<IRBasicType>(type) || as<IRVectorType>(type)) return type; switch (type->getOp()) { case kIROp_TensorViewType: return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); - +#if 0 + case kIROp_VectorType: + { + // Create a new struct type representing the vector. 
+ auto hostStructType = builder->createStructType(); + const char* names[4] = { "x", "y", "z", "w" }; + for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++) + { + auto key = builder->createStructKey(); + if (i < 4) + builder->addNameHintDecoration(key, UnownedStringSlice(names[i])); + builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType()); + } + return hostStructType; + } +#endif case kIROp_StructType: { // Create a new struct type with translated fields. List<IRType*> fieldTypes; for (auto field : as<IRStructType>(type)->getFields()) { - fieldTypes.add(translateToHostType(builder, field->getFieldType())); + fieldTypes.add(translateToHostType(builder, field->getFieldType(), func)); } auto hostStructType = builder->createStructType(); @@ -348,12 +363,14 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* si } if (sink) - sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type); + sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type, func); return nullptr; } IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst) { + if (hostType == cudaType) + return inst; if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType)) return inst; @@ -361,7 +378,18 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp { case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); - +#if 0 + case kIROp_VectorType: + { + List<IRInst*> args; + auto hostStructType = cast<IRStructType>(hostType); + for (auto field : hostStructType->getFields()) + { + args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey())); + } + return builder->emitMakeVector(cudaType, args); + } +#endif case kIROp_StructType: { auto cudaStructType = cast<IRStructType>(cudaType); @@ -522,7 +550,7 @@ void generateReflectionFunc(IRBuilder* builder, 
IRFunc* kernelFunc, IRFunc* host IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr) { - auto type = translateToHostType(builder, param->getDataType(), sink); + auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink); if (outType) *outType = type; auto hostParam = builder->emitParam(type); @@ -859,6 +887,7 @@ void handleAutoBindNames(IRModule* module) nameBuilder << "__kernel__" << externCppHint->getName(); externCppHint->removeAndDeallocate(); builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); + builder.addExternCDecoration(globalInst); } } } @@ -915,7 +944,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. if (func->findDecoration<IRCudaKernelDecoration>()) + { builder.addCudaKernelDecoration(wrapperFunc); + builder.addExternCDecoration(wrapperFunc); + } // Add an auto-pybind-cuda decoration to the wrapper function to further generate the // host-side binding for the derivative kernel. @@ -971,7 +1003,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. if (func->findDecoration<IRCudaKernelDecoration>()) + { builder.addCudaKernelDecoration(wrapperFunc); + builder.addExternCDecoration(wrapperFunc); + } // Add an auto-pybind-cuda decoration to the wrapper function to further generate the // host-side binding for the derivative kernel. |
