summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp51
1 files changed, 43 insertions, 8 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index c723902de..41665ddf7 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -193,7 +193,7 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
IRIntegerValue i = 0;
for (auto field : structType->getFields())
{
- auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i));
auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
if (!convertedElement)
return nullptr;
@@ -315,23 +315,38 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
inst->removeAndDeallocate();
}
-IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr)
+IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr)
{
- if (as<IRBasicType>(type))
+ if (as<IRBasicType>(type) || as<IRVectorType>(type))
return type;
switch (type->getOp())
{
case kIROp_TensorViewType:
return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
-
+#if 0
+ case kIROp_VectorType:
+ {
+ // Create a new struct type representing the vector.
+ auto hostStructType = builder->createStructType();
+ const char* names[4] = { "x", "y", "z", "w" };
+ for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++)
+ {
+ auto key = builder->createStructKey();
+ if (i < 4)
+ builder->addNameHintDecoration(key, UnownedStringSlice(names[i]));
+ builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType());
+ }
+ return hostStructType;
+ }
+#endif
case kIROp_StructType:
{
// Create a new struct type with translated fields.
List<IRType*> fieldTypes;
for (auto field : as<IRStructType>(type)->getFields())
{
- fieldTypes.add(translateToHostType(builder, field->getFieldType()));
+ fieldTypes.add(translateToHostType(builder, field->getFieldType(), func));
}
auto hostStructType = builder->createStructType();
@@ -348,12 +363,14 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* si
}
if (sink)
- sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type);
+ sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type, func);
return nullptr;
}
IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst)
{
+ if (hostType == cudaType)
+ return inst;
if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType))
return inst;
@@ -361,7 +378,18 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
{
case kIROp_TensorViewType:
return builder->emitMakeTensorView(cudaType, inst);
-
+#if 0
+ case kIROp_VectorType:
+ {
+ List<IRInst*> args;
+ auto hostStructType = cast<IRStructType>(hostType);
+ for (auto field : hostStructType->getFields())
+ {
+ args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey()));
+ }
+ return builder->emitMakeVector(cudaType, args);
+ }
+#endif
case kIROp_StructType:
{
auto cudaStructType = cast<IRStructType>(cudaType);
@@ -522,7 +550,7 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr)
{
- auto type = translateToHostType(builder, param->getDataType(), sink);
+ auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink);
if (outType)
*outType = type;
auto hostParam = builder->emitParam(type);
@@ -859,6 +887,7 @@ void handleAutoBindNames(IRModule* module)
nameBuilder << "__kernel__" << externCppHint->getName();
externCppHint->removeAndDeallocate();
builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice());
+ builder.addExternCDecoration(globalInst);
}
}
}
@@ -915,7 +944,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
// If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
if (func->findDecoration<IRCudaKernelDecoration>())
+ {
builder.addCudaKernelDecoration(wrapperFunc);
+ builder.addExternCDecoration(wrapperFunc);
+ }
// Add an auto-pybind-cuda decoration to the wrapper function to further generate the
// host-side binding for the derivative kernel.
@@ -971,7 +1003,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
// If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
if (func->findDecoration<IRCudaKernelDecoration>())
+ {
builder.addCudaKernelDecoration(wrapperFunc);
+ builder.addExternCDecoration(wrapperFunc);
+ }
// Add an auto-pybind-cuda decoration to the wrapper function to further generate the
// host-side binding for the derivative kernel.