diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-01 14:50:04 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-01 14:50:04 -0400 |
| commit | 6e550430cf24a6324ce22f821181f2cab904d543 (patch) | |
| tree | f0552bbbb4a4f3c2cb503355d842baf6a7b8753d | |
| parent | 0e71a6d40d2ccdc9e6bb861e7bbdb9479dbec636 (diff) | |
Error out when constructing tensor views from tensors with 0 stride. (#4516)
This avoids a problem with broadcasted tensors. Our tensor-view platform is designed to allow unrestricted access to tensor memory, whereas broadcasted tensors are designed for read-only use cases. Writing into a broadcasted tensor requires re-allocating its memory, which Slang is not designed to do.
For now, we raise an error for any tensor that has a 0 stride in any dimension; such tensors must first be made contiguous (e.g. with tensor.contiguous()).
In the future, we will introduce a ConstTensorView object to allow such tensors to be used as an input.
This patch also propagates name-hint information through structs and arrays of tensors so that the error messages can use sensible names (previously the error messages referred to tensors by temporary instruction numbers, which made them nearly impossible to debug).
| -rw-r--r-- | prelude/slang-torch-prelude.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 97 |
2 files changed, 93 insertions, 7 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index bdba620fe..28548e48e 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -140,6 +140,9 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy for (int i = 0; i < val.dim(); ++i) { res.strides[i] = val.stride(i) * elementSize; + if (res.strides[i] == 0) + throw std::runtime_error(std::string(name).append(": tensors with broadcasted dimensions are not supported (use tensor.contiguous() to make tensor whole)").c_str()); + res.sizes[i] = val.size(i); if (res.sizes[i] > 0) isEmpty = false; diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 6922984d6..2c91dc3b6 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -262,6 +262,15 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i)); + + // Make a name hint: <valname>_<i> + if (auto nameHint = val->findDecoration<IRNameHintDecoration>()) + { + StringBuilder newName; + newName << nameHint->getName() << "_" << i; + builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice()); + } + auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; @@ -275,7 +284,22 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst IRIntegerValue i = 0; for (auto field : structType->getFields()) { - auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement( + translateToTupleType(builder, field->getFieldType()), + val, 
+ builder.getIntValue(builder.getIntType(), i)); + + // Make a name hint: <valname>_<fieldname> + if (auto nameHint = val->findDecoration<IRNameHintDecoration>()) + { + if (auto fieldHint = field->getKey()->findDecoration<IRNameHintDecoration>()) + { + StringBuilder newName; + newName << nameHint->getName() << "_" << fieldHint->getName(); + builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice()); + } + } + auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); if (!convertedElement) return nullptr; @@ -414,16 +438,21 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag { // Create a new struct type with translated fields. List<IRType*> fieldTypes; + List<IRNameHintDecoration*> fieldNames; for (auto field : as<IRStructType>(type)->getFields()) { fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink)); + fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>()); } auto hostStructType = builder->createStructType(); // Add fields to the struct. for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++) { - builder->createStructField(hostStructType, builder->createStructKey(), fieldTypes[i]); + auto structKey = builder->createStructKey(); + if (fieldNames[i]) + builder->addNameHintDecoration(structKey, fieldNames[i]->getName()); + builder->createStructField(hostStructType, structKey, fieldTypes[i]); } return hostStructType; @@ -444,6 +473,43 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag return nullptr; } +// Propagates name hints through field extracts. +IRInst* propagateNameHint(IRBuilder* builder, IRFieldExtract* inst) +{ + // If the field has a name hint, propagate it to the inst by appending the field name to the inst name + // (which must be fetched from the inst's name hint decoration). + // + // This is useful for propagating the name hint from a struct field to the inst that extracts the field. 
+ if (auto nameHint = inst->getField()->findDecoration<IRNameHintDecoration>()) + { + if (auto instNameHint = inst->getBase()->findDecoration<IRNameHintDecoration>()) + { + StringBuilder newName; + newName << instNameHint->getName() << "_" << nameHint->getName(); + builder->addNameHintDecoration(inst, newName.getUnownedSlice()); + } + } + + return inst; +} + +// Propagates name hints through array indexing +IRInst* propagateNameHint(IRBuilder* builder, IRGetElement* inst) +{ + // If the index is a constant, we can propagate the name hint from the inst to the index. + if (auto intLit = as<IRIntLit>(inst->getIndex())) + { + if (auto nameHint = inst->getBase()->findDecoration<IRNameHintDecoration>()) + { + StringBuilder newName; + newName << nameHint->getName() << "_" << intLit->getValue(); + builder->addNameHintDecoration(inst, newName.getUnownedSlice()); + } + } + + return inst; +} + IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst) { if (hostType == cudaType) @@ -479,7 +545,9 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp builder, hostFieldType, cudaFieldType, - builder->emitFieldExtract(hostFieldType, inst, hostField->getKey())); + propagateNameHint( + builder, + cast<IRFieldExtract>(builder->emitFieldExtract(hostFieldType, inst, hostField->getKey())))); SLANG_RELEASE_ASSERT(castedField); resultFields.add(castedField); @@ -501,7 +569,9 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp builder, hostElementType, cudaElementType, - builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i))); + propagateNameHint( + builder, + cast<IRGetElement>(builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i))))); SLANG_RELEASE_ASSERT(castedElement); resultElements.add(castedElement); @@ -647,7 +717,8 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno } auto hostParam = 
builder->emitParam(type); - // Add a namehint to the param by appending the suffix "_host". + + // Add a namehint to the param if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) { builder->addNameHintDecoration(hostParam, nameHint->getName()); @@ -1219,7 +1290,13 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) List<IRInst*> params; for (auto param : func->getFirstBlock()->getParams()) { - params.add(builder.emitParam(param->getFullType())); + auto newParam = builder.emitParam(param->getFullType()); + + // Copy over the name hint. + if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) + builder.addNameHintDecoration(newParam, nameHint->getName()); + + params.add(newParam); } wrapperFunc->setFullType(func->getFullType()); @@ -1277,7 +1354,13 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) List<IRInst*> params; for (auto param : func->getFirstBlock()->getParams()) { - params.add(builder.emitParam(param->getFullType())); + auto newParam = builder.emitParam(param->getFullType()); + + // Copy over the name hint. + if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) + builder.addNameHintDecoration(newParam, nameHint->getName()); + + params.add(newParam); } wrapperFunc->setFullType(func->getFullType()); |
