summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp202
1 files changed, 164 insertions, 38 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index e33adec1d..eb81bfd8c 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -5,113 +5,204 @@
namespace Slang
{
-static bool getHostReturnTypeImpl(List<IRType*>& elementTypes, IRBuilder& builder, IRType* type)
+// Convert a type to a target tuple type.
+static IRType* translateToTupleType(IRBuilder& builder, IRType* type)
{
- bool isValid = true;
if (as<IRVoidType>(type))
- return true;
+ return type;
if (as<IRBasicType>(type))
- elementTypes.add(type);
+ return type;
else if (as<IRTorchTensorType>(type))
- elementTypes.add(type);
+ return type;
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
if (!count)
{
- return false;
+ return nullptr;
}
+ List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < count->getValue(); i++)
{
elementTypes.addRange(vectorType->getElementType());
}
+ return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
}
else if (auto arrayType = as<IRArrayType>(type))
{
auto arraySize = as<IRIntLit>(arrayType->getElementCount());
if (!arraySize)
{
- return false;
+ return nullptr;
}
List<IRType*> subElementTypes;
- isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType());
+ auto subElementType = translateToTupleType(builder, arrayType->getElementType());
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- elementTypes.addRange(subElementTypes);
+ subElementTypes.addRange(subElementType);
}
+ return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
+ List<IRType*> elementTypes;
for (auto field : structType->getFields())
{
- isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType());
+ auto fieldType = translateToTupleType(builder, field->getFieldType());
+ if (!fieldType)
+ {
+ return nullptr;
+ }
+ elementTypes.addRange(fieldType);
}
+ return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
}
else
{
- return false;
+ return nullptr;
}
- return isValid;
-}
-
-static IRType* getHostReturnType(IRBuilder& builder, IRType* type)
-{
- List<IRType*> types;
- bool isValid = getHostReturnTypeImpl(types, builder, type);
- if (isValid)
- return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer());
- return nullptr;
}
-static void flattenToTupleImpl(List<IRInst*>& result, IRBuilder& builder, IRInst* val)
+// Convert a value to a target tuple type.
+static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
{
auto type = val->getDataType();
if (as<IRVoidType>(type))
- return;
+ return val;
if (as<IRBasicType>(type))
- result.add(val);
+ return val;
else if (as<IRTorchTensorType>(type))
- result.add(val);
+ return val;
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
if (!count)
{
- return;
+ return nullptr;
}
+ List<IRInst*> resultElements;
+ List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < count->getValue(); i++)
{
- result.add(builder.emitElementExtract(vectorType->getElementType(), builder.getIntValue(builder.getIntType(), i)));
+ auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ resultElements.add(tupleElement);
+ elementTypes.add(tupleElement->getFullType());
}
+ auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
else if (auto arrayType = as<IRArrayType>(type))
{
auto arraySize = as<IRIntLit>(arrayType->getElementCount());
if (!arraySize)
{
- return;
+ return nullptr;
}
+ List<IRInst*> resultElements;
+ List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
- flattenToTupleImpl(result, builder, elementVal);
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ resultElements.add(tupleElement);
+ elementTypes.add(tupleElement->getFullType());
}
+ auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
+ List<IRInst*> resultElements;
+ List<IRType*> elementTypes;
for (auto field : structType->getFields())
{
auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey());
- flattenToTupleImpl(result, builder, elementVal);
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ resultElements.add(tupleElement);
+ elementTypes.add(tupleElement->getFullType());
}
+ auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else
+ {
+ return nullptr;
}
}
-static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val)
+// Convert a target tuple type to a value.
+static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst* val)
{
- List<IRInst*> vals;
- flattenToTupleImpl(vals, builder, val);
- return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer());
+ if (as<IRVoidType>(type))
+ return val;
+ if (as<IRBasicType>(type))
+ return val;
+ else if (as<IRTorchTensorType>(type))
+ return val;
+ else if (auto vectorType = as<IRVectorType>(type))
+ {
+ auto count = as<IRIntLit>(vectorType->getElementCount());
+ if (!count)
+ {
+ return nullptr;
+ }
+ List<IRInst*> resultElements;
+ auto elementType = vectorType->getElementType();
+ for (IRIntegerValue i = 0; i < count->getValue(); i++)
+ {
+ auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
+ if (!convertedElement)
+ return nullptr;
+ resultElements.add(convertedElement);
+ }
+ return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+ if (!arraySize)
+ {
+ return nullptr;
+ }
+ List<IRInst*> resultElements;
+ auto elementType = arrayType->getElementType();
+ for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
+ {
+ auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
+ if (!convertedElement)
+ return nullptr;
+ resultElements.add(convertedElement);
+ }
+ return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else if (auto structType = as<IRStructType>(type))
+ {
+ List<IRInst*> resultElements;
+ IRIntegerValue i = 0;
+ for (auto field : structType->getFields())
+ {
+ auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i));
+ auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
+ if (!convertedElement)
+ return nullptr;
+ resultElements.add(convertedElement);
+ i++;
+ }
+ return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else
+ {
+ return nullptr;
+ }
}
static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
@@ -119,7 +210,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
IRBuilder builder(func);
builder.setInsertBefore(func);
- auto hostReturnType = getHostReturnType(builder, func->getResultType());
+ auto hostReturnType = translateToTupleType(builder, func->getResultType());
if (!hostReturnType)
{
sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType());
@@ -129,15 +220,50 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
auto funcType = as<IRFuncType>(func->getDataType());
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
- hostParamTypes.add(funcType->getParamType(i));
+ hostParamTypes.add(translateToTupleType(builder, funcType->getParamType(i)));
}
auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType);
func->setFullType(bindingFuncType);
builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType));
List<IRInst*> instsToRemove;
+ List<IRInst*> oldParams;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ oldParams.add(param);
+ }
+
+ List<IRInst*> newParams;
+ for (auto param : oldParams)
+ {
+ auto paramType = param->getFullType();
+ auto newParamType = translateToTupleType(builder, paramType);
+ if (!newParamType)
+ {
+ sink->diagnose(param->sourceLoc, Diagnostics::invalidTorchKernelParamType, paramType);
+ return;
+ }
+ auto newParam = builder.emitParam(newParamType);
+ newParams.add(newParam);
+ }
+
+ // Convert all new parameters from tuples to their original types.
+ for (Index i = 0; i < newParams.getCount(); i++)
+ {
+ auto oldParam = oldParams[i];
+ auto newParam = newParams[i];
+ auto convertedParam = makeValueFromTargetTuple(builder, oldParam->getFullType(), newParam);
+ if (!convertedParam)
+ {
+ return;
+ }
+ oldParam->replaceUsesWith(convertedParam);
+ oldParam->removeAndDeallocate();
+ }
+
+ auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType));
+
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
@@ -177,7 +303,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
else if (auto ret = as<IRReturn>(inst))
{
builder.setInsertBefore(ret);
- auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal());
+ auto retVal = makeTargetTuple(builder, ret->getVal());
ret->setOperand(0, retVal);
}
}