summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-27 23:00:42 -0700
committerGitHub <noreply@github.com>2023-03-27 23:00:42 -0700
commit0a6926003fd2300858e3089fe82f421543852395 (patch)
tree19865fa9eb69373f0c0c16b7fac4993f67aa2b20 /source
parentd120fec7e81bbd5e8cf2c551b573feaf6678b43d (diff)
Translate all composed types into tuple types in pyBind. (#2744)
* Translate all composed types into tuple types in pyBind. * Delete temp file. * Fix get tuple element code emit logic. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-cpp.cpp11
-rw-r--r--source/slang/slang-emit-torch.cpp12
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp202
-rw-r--r--source/slang/slang-ir.cpp6
6 files changed, 202 insertions, 39 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 346926712..a178dfe67 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1212,6 +1212,17 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
m_writer->emit(")");
return true;
}
+ case kIROp_GetTargetTupleElement:
+ {
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ m_writer->emit("std::get<");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(">(");
+ emitOperand(inst->getOperand(0), leftSide(outerPrec, prec));
+ m_writer->emit(")");
+ return true;
+ }
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_FloatCast:
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index ef67c520a..877c1dc03 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -83,7 +83,17 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
emitOperand(arg, getInfo(EmitOp::General));
}
m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
- switch (inst->getDataType()->getOperand(0)->getOp())
+
+ // Get the element type of the tensor.
+ auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0);
+
+ // If instType is a vector type, then we need to get the element type.
+ if (auto vectorType = as<IRVectorType>(instType))
+ {
+ instType = vectorType->getElementType();
+ }
+
+ switch (instType->getOp())
{
case kIROp_FloatType:
m_writer->emit("kFloat32");
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 68f1a28e6..e58094b15 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -315,6 +315,7 @@ INST(MakeArrayFromElement, makeArrayFromElement, 1, 0)
INST(MakeStruct, makeStruct, 0, 0)
INST(MakeTuple, makeTuple, 0, 0)
INST(MakeTargetTuple, makeTuple, 0, 0)
+INST(GetTargetTupleElement, getTargetTupleElement, 0, 0)
INST(GetTupleElement, getTupleElement, 2, 0)
INST(MakeResultValue, makeResultValue, 1, 0)
INST(MakeResultError, makeResultError, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 4cdf6c749..f5b03eb45 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2315,6 +2315,13 @@ struct IRGetTupleElement : IRInst
IRInst* getElementIndex() { return getOperand(1); }
};
+struct IRGetTargetTupleElement : IRInst
+{
+ IR_LEAF_ISA(GetTargetTupleElement)
+ IRInst* getTuple() { return getOperand(0); }
+ IRInst* getElementIndex() { return getOperand(1); }
+};
+
// An Instruction that creates a differential pair value from a
// primal and differential.
@@ -3031,6 +3038,8 @@ public:
IRInst* emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args);
+ IRInst* emitTargetTupleGetElement(IRType* elementType, IRInst* targetTupleVal, IRInst* indexVal);
+
IRInst* emitMakeTuple(IRType* type, UInt count, IRInst* const* args);
IRInst* emitMakeTuple(UInt count, IRInst* const* args);
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index e33adec1d..eb81bfd8c 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -5,113 +5,204 @@
namespace Slang
{
-static bool getHostReturnTypeImpl(List<IRType*>& elementTypes, IRBuilder& builder, IRType* type)
+// Convert a type to a target tuple type.
+static IRType* translateToTupleType(IRBuilder& builder, IRType* type)
{
- bool isValid = true;
if (as<IRVoidType>(type))
- return true;
+ return type;
if (as<IRBasicType>(type))
- elementTypes.add(type);
+ return type;
else if (as<IRTorchTensorType>(type))
- elementTypes.add(type);
+ return type;
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
if (!count)
{
- return false;
+ return nullptr;
}
+ List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < count->getValue(); i++)
{
elementTypes.addRange(vectorType->getElementType());
}
+ return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
}
else if (auto arrayType = as<IRArrayType>(type))
{
auto arraySize = as<IRIntLit>(arrayType->getElementCount());
if (!arraySize)
{
- return false;
+ return nullptr;
}
List<IRType*> subElementTypes;
- isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType());
+ auto subElementType = translateToTupleType(builder, arrayType->getElementType());
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- elementTypes.addRange(subElementTypes);
+ subElementTypes.addRange(subElementType);
}
+ return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
+ List<IRType*> elementTypes;
for (auto field : structType->getFields())
{
- isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType());
+ auto fieldType = translateToTupleType(builder, field->getFieldType());
+ if (!fieldType)
+ {
+ return nullptr;
+ }
+ elementTypes.addRange(fieldType);
}
+ return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
}
else
{
- return false;
+ return nullptr;
}
- return isValid;
-}
-
-static IRType* getHostReturnType(IRBuilder& builder, IRType* type)
-{
- List<IRType*> types;
- bool isValid = getHostReturnTypeImpl(types, builder, type);
- if (isValid)
- return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer());
- return nullptr;
}
-static void flattenToTupleImpl(List<IRInst*>& result, IRBuilder& builder, IRInst* val)
+// Convert a value to a target tuple type.
+static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
{
auto type = val->getDataType();
if (as<IRVoidType>(type))
- return;
+ return val;
if (as<IRBasicType>(type))
- result.add(val);
+ return val;
else if (as<IRTorchTensorType>(type))
- result.add(val);
+ return val;
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
if (!count)
{
- return;
+ return nullptr;
}
+ List<IRInst*> resultElements;
+ List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < count->getValue(); i++)
{
- result.add(builder.emitElementExtract(vectorType->getElementType(), builder.getIntValue(builder.getIntType(), i)));
+ auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ resultElements.add(tupleElement);
+ elementTypes.add(tupleElement->getFullType());
}
+ auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
else if (auto arrayType = as<IRArrayType>(type))
{
auto arraySize = as<IRIntLit>(arrayType->getElementCount());
if (!arraySize)
{
- return;
+ return nullptr;
}
+ List<IRInst*> resultElements;
+ List<IRType*> elementTypes;
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
- flattenToTupleImpl(result, builder, elementVal);
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ resultElements.add(tupleElement);
+ elementTypes.add(tupleElement->getFullType());
}
+ auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
else if (auto structType = as<IRStructType>(type))
{
+ List<IRInst*> resultElements;
+ List<IRType*> elementTypes;
for (auto field : structType->getFields())
{
auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey());
- flattenToTupleImpl(result, builder, elementVal);
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ resultElements.add(tupleElement);
+ elementTypes.add(tupleElement->getFullType());
}
+ auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else
+ {
+ return nullptr;
}
}
-static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val)
+// Convert a target tuple type to a value.
+static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst* val)
{
- List<IRInst*> vals;
- flattenToTupleImpl(vals, builder, val);
- return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer());
+ if (as<IRVoidType>(type))
+ return val;
+ if (as<IRBasicType>(type))
+ return val;
+ else if (as<IRTorchTensorType>(type))
+ return val;
+ else if (auto vectorType = as<IRVectorType>(type))
+ {
+ auto count = as<IRIntLit>(vectorType->getElementCount());
+ if (!count)
+ {
+ return nullptr;
+ }
+ List<IRInst*> resultElements;
+ auto elementType = vectorType->getElementType();
+ for (IRIntegerValue i = 0; i < count->getValue(); i++)
+ {
+ auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
+ if (!convertedElement)
+ return nullptr;
+ resultElements.add(convertedElement);
+ }
+ return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+ if (!arraySize)
+ {
+ return nullptr;
+ }
+ List<IRInst*> resultElements;
+ auto elementType = arrayType->getElementType();
+ for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
+ {
+ auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
+ if (!convertedElement)
+ return nullptr;
+ resultElements.add(convertedElement);
+ }
+ return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else if (auto structType = as<IRStructType>(type))
+ {
+ List<IRInst*> resultElements;
+ IRIntegerValue i = 0;
+ for (auto field : structType->getFields())
+ {
+ auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i));
+ auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
+ if (!convertedElement)
+ return nullptr;
+ resultElements.add(convertedElement);
+ i++;
+ }
+ return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
+ else
+ {
+ return nullptr;
+ }
}
static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
@@ -119,7 +210,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
IRBuilder builder(func);
builder.setInsertBefore(func);
- auto hostReturnType = getHostReturnType(builder, func->getResultType());
+ auto hostReturnType = translateToTupleType(builder, func->getResultType());
if (!hostReturnType)
{
sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType());
@@ -129,15 +220,50 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
auto funcType = as<IRFuncType>(func->getDataType());
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
- hostParamTypes.add(funcType->getParamType(i));
+ hostParamTypes.add(translateToTupleType(builder, funcType->getParamType(i)));
}
auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType);
func->setFullType(bindingFuncType);
builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType));
List<IRInst*> instsToRemove;
+ List<IRInst*> oldParams;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ oldParams.add(param);
+ }
+
+ List<IRInst*> newParams;
+ for (auto param : oldParams)
+ {
+ auto paramType = param->getFullType();
+ auto newParamType = translateToTupleType(builder, paramType);
+ if (!newParamType)
+ {
+ sink->diagnose(param->sourceLoc, Diagnostics::invalidTorchKernelParamType, paramType);
+ return;
+ }
+ auto newParam = builder.emitParam(newParamType);
+ newParams.add(newParam);
+ }
+
+ // Convert all new parameters from tuples to their original types.
+ for (Index i = 0; i < newParams.getCount(); i++)
+ {
+ auto oldParam = oldParams[i];
+ auto newParam = newParams[i];
+ auto convertedParam = makeValueFromTargetTuple(builder, oldParam->getFullType(), newParam);
+ if (!convertedParam)
+ {
+ return;
+ }
+ oldParam->replaceUsesWith(convertedParam);
+ oldParam->removeAndDeallocate();
+ }
+
+ auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType));
+
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
@@ -177,7 +303,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
else if (auto ret = as<IRReturn>(inst))
{
builder.setInsertBefore(ret);
- auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal());
+ auto retVal = makeTargetTuple(builder, ret->getVal());
ret->setOperand(0, retVal);
}
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6ce54a948..76e889780 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3705,6 +3705,12 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeTargetTuple, count, args);
}
+ IRInst* IRBuilder::emitTargetTupleGetElement(IRType* elementType, IRInst* targetTupleVal, IRInst* indexVal)
+ {
+ IRInst* args[] = {targetTupleVal, indexVal};
+ return emitIntrinsicInst(elementType, kIROp_GetTargetTupleElement, 2, args);
+ }
+
IRInst* IRBuilder::emitMakeTuple(UInt count, IRInst* const* args)
{
List<IRType*> types;