diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 202 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 6 |
6 files changed, 202 insertions, 39 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 346926712..a178dfe67 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1212,6 +1212,17 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut m_writer->emit(")"); return true; } + case kIROp_GetTargetTupleElement: + { + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + m_writer->emit("std::get<"); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(">("); + emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); + m_writer->emit(")"); + return true; + } case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: case kIROp_FloatCast: diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index ef67c520a..877c1dc03 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -83,7 +83,17 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitOperand(arg, getInfo(EmitOp::General)); } m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); - switch (inst->getDataType()->getOperand(0)->getOp()) + + // Get the element type of the tensor. + auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); + + // If instType is a vector type, then we need to get the element type. + if (auto vectorType = as<IRVectorType>(instType)) + { + instType = vectorType->getElementType(); + } + + switch (instType->getOp()) { case kIROp_FloatType: m_writer->emit("kFloat32"); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 68f1a28e6..e58094b15 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -315,6 +315,7 @@ INST(MakeArrayFromElement, makeArrayFromElement, 1, 0) INST(MakeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) INST(MakeTargetTuple, makeTuple, 0, 0) +INST(GetTargetTupleElement, getTargetTupleElement, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) INST(MakeResultValue, makeResultValue, 1, 0) INST(MakeResultError, makeResultError, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4cdf6c749..f5b03eb45 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2315,6 +2315,13 @@ struct IRGetTupleElement : IRInst IRInst* getElementIndex() { return getOperand(1); } }; +struct IRGetTargetTupleElement : IRInst +{ + IR_LEAF_ISA(GetTargetTupleElement) + IRInst* getTuple() { return getOperand(0); } + IRInst* getElementIndex() { return getOperand(1); } +}; + // An Instruction that creates a differential pair value from a // primal and differential. @@ -3031,6 +3038,8 @@ public: IRInst* emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args); + IRInst* emitTargetTupleGetElement(IRType* elementType, IRInst* targetTupleVal, IRInst* indexVal); + IRInst* emitMakeTuple(IRType* type, UInt count, IRInst* const* args); IRInst* emitMakeTuple(UInt count, IRInst* const* args); diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index e33adec1d..eb81bfd8c 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -5,113 +5,204 @@ namespace Slang { -static bool getHostReturnTypeImpl(List<IRType*>& elementTypes, IRBuilder& builder, IRType* type) +// Convert a type to a target tuple type. +static IRType* translateToTupleType(IRBuilder& builder, IRType* type) { - bool isValid = true; if (as<IRVoidType>(type)) - return true; + return type; if (as<IRBasicType>(type)) - elementTypes.add(type); + return type; else if (as<IRTorchTensorType>(type)) - elementTypes.add(type); + return type; else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); if (!count) { - return false; + return nullptr; } + List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < count->getValue(); i++) { elementTypes.addRange(vectorType->getElementType()); } + return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } else if (auto arrayType = as<IRArrayType>(type)) { auto arraySize = as<IRIntLit>(arrayType->getElementCount()); if (!arraySize) { - return false; + return nullptr; } List<IRType*> subElementTypes; - isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType()); + auto subElementType = translateToTupleType(builder, arrayType->getElementType()); for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { - elementTypes.addRange(subElementTypes); + subElementTypes.addRange(subElementType); } + return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { + List<IRType*> elementTypes; for (auto field : structType->getFields()) { - isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType()); + auto fieldType = translateToTupleType(builder, field->getFieldType()); + if (!fieldType) + { + return nullptr; + } + elementTypes.addRange(fieldType); } + return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } else { - return false; + return nullptr; } - return isValid; -} - -static IRType* getHostReturnType(IRBuilder& builder, IRType* type) -{ - List<IRType*> types; - bool isValid = getHostReturnTypeImpl(types, builder, type); - if (isValid) - return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer()); - return nullptr; } -static void flattenToTupleImpl(List<IRInst*>& result, IRBuilder& builder, IRInst* val) +// Convert a value to a target tuple type. +static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) { auto type = val->getDataType(); if (as<IRVoidType>(type)) - return; + return val; if (as<IRBasicType>(type)) - result.add(val); + return val; else if (as<IRTorchTensorType>(type)) - result.add(val); + return val; else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); if (!count) { - return; + return nullptr; } + List<IRInst*> resultElements; + List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < count->getValue(); i++) { - result.add(builder.emitElementExtract(vectorType->getElementType(), builder.getIntValue(builder.getIntType(), i))); + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + resultElements.add(tupleElement); + elementTypes.add(tupleElement->getFullType()); } + auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto arrayType = as<IRArrayType>(type)) { auto arraySize = as<IRIntLit>(arrayType->getElementCount()); if (!arraySize) { - return; + return nullptr; } + List<IRInst*> resultElements; + List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); - flattenToTupleImpl(result, builder, elementVal); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + resultElements.add(tupleElement); + elementTypes.add(tupleElement->getFullType()); } + auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { + List<IRInst*> resultElements; + List<IRType*> elementTypes; for (auto field : structType->getFields()) { auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey()); - flattenToTupleImpl(result, builder, elementVal); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + resultElements.add(tupleElement); + elementTypes.add(tupleElement->getFullType()); } + auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else + { + return nullptr; } } -static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val) +// Convert a target tuple type to a value. +static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst* val) { - List<IRInst*> vals; - flattenToTupleImpl(vals, builder, val); - return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer()); + if (as<IRVoidType>(type)) + return val; + if (as<IRBasicType>(type)) + return val; + else if (as<IRTorchTensorType>(type)) + return val; + else if (auto vectorType = as<IRVectorType>(type)) + { + auto count = as<IRIntLit>(vectorType->getElementCount()); + if (!count) + { + return nullptr; + } + List<IRInst*> resultElements; + auto elementType = vectorType->getElementType(); + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); + if (!convertedElement) + return nullptr; + resultElements.add(convertedElement); + } + return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) + { + return nullptr; + } + List<IRInst*> resultElements; + auto elementType = arrayType->getElementType(); + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); + if (!convertedElement) + return nullptr; + resultElements.add(convertedElement); + } + return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else if (auto structType = as<IRStructType>(type)) + { + List<IRInst*> resultElements; + IRIntegerValue i = 0; + for (auto field : structType->getFields()) + { + auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i)); + auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); + if (!convertedElement) + return nullptr; + resultElements.add(convertedElement); + i++; + } + return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else + { + return nullptr; + } } static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) @@ -119,7 +210,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) IRBuilder builder(func); builder.setInsertBefore(func); - auto hostReturnType = getHostReturnType(builder, func->getResultType()); + auto hostReturnType = translateToTupleType(builder, func->getResultType()); if (!hostReturnType) { sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType()); @@ -129,15 +220,50 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) auto funcType = as<IRFuncType>(func->getDataType()); for (UInt i = 0; i < funcType->getParamCount(); i++) { - hostParamTypes.add(funcType->getParamType(i)); + hostParamTypes.add(translateToTupleType(builder, funcType->getParamType(i))); } auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType); func->setFullType(bindingFuncType); builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); List<IRInst*> instsToRemove; + List<IRInst*> oldParams; + for (auto param : func->getFirstBlock()->getParams()) + { + oldParams.add(param); + } + + List<IRInst*> newParams; + for (auto param : oldParams) + { + auto paramType = param->getFullType(); + auto newParamType = translateToTupleType(builder, paramType); + if (!newParamType) + { + sink->diagnose(param->sourceLoc, Diagnostics::invalidTorchKernelParamType, paramType); + return; + } + auto newParam = builder.emitParam(newParamType); + newParams.add(newParam); + } + + // Convert all new parameters from tuples to their original types. + for (Index i = 0; i < newParams.getCount(); i++) + { + auto oldParam = oldParams[i]; + auto newParam = newParams[i]; + auto convertedParam = makeValueFromTargetTuple(builder, oldParam->getFullType(), newParam); + if (!convertedParam) + { + return; + } + oldParam->replaceUsesWith(convertedParam); + oldParam->removeAndDeallocate(); + } + + auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); + for (auto block : func->getBlocks()) { for (auto inst : block->getChildren()) @@ -177,7 +303,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) else if (auto ret = as<IRReturn>(inst)) { builder.setInsertBefore(ret); - auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal()); + auto retVal = makeTargetTuple(builder, ret->getVal()); ret->setOperand(0, retVal); } } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6ce54a948..76e889780 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3705,6 +3705,12 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeTargetTuple, count, args); } + IRInst* IRBuilder::emitTargetTupleGetElement(IRType* elementType, IRInst* targetTupleVal, IRInst* indexVal) + { + IRInst* args[] = {targetTupleVal, indexVal}; + return emitIntrinsicInst(elementType, kIROp_GetTargetTupleElement, 2, args); + } + IRInst* IRBuilder::emitMakeTuple(UInt count, IRInst* const* args) { List<IRType*> types; |
