summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-01 14:50:04 -0400
committerGitHub <noreply@github.com>2024-07-01 14:50:04 -0400
commit6e550430cf24a6324ce22f821181f2cab904d543 (patch)
treef0552bbbb4a4f3c2cb503355d842baf6a7b8753d
parent0e71a6d40d2ccdc9e6bb861e7bbdb9479dbec636 (diff)
Error out when constructing tensor views from tensors with 0 stride. (#4516)
This avoids a problem with broadcasted tensors. Our tensor-view platform is designed to allow unrestricted access to tensor memory, while broadcasted tensors were designed for 'read-only' use-cases. Trying to write into a broadcasted tensor needs re-allocation, which Slang is not designed to do. For now, we enforce contiguity on tensors with any 0 strides. In the future, we will introduce a ConstTensorView object to allow such tensors to be used as an input. This patch also propagates name-hint information through structs & arrays of tensors, to allow sensible names for the error messages (before this the error messages were temporary inst numbers, which is nearly impossible to debug)
-rw-r--r--prelude/slang-torch-prelude.h3
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp97
2 files changed, 93 insertions, 7 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index bdba620fe..28548e48e 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -140,6 +140,9 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
for (int i = 0; i < val.dim(); ++i)
{
res.strides[i] = val.stride(i) * elementSize;
+ if (res.strides[i] == 0)
+ throw std::runtime_error(std::string(name).append(": tensors with broadcasted dimensions are not supported (use tensor.contiguous() to make tensor whole)").c_str());
+
res.sizes[i] = val.size(i);
if (res.sizes[i] > 0)
isEmpty = false;
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 6922984d6..2c91dc3b6 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -262,6 +262,15 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i));
+
+ // Make a name hint: <valname>_<i>
+ if (auto nameHint = val->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_" << i;
+ builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice());
+ }
+
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
@@ -275,7 +284,22 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
IRIntegerValue i = 0;
for (auto field : structType->getFields())
{
- auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(
+ translateToTupleType(builder, field->getFieldType()),
+ val,
+ builder.getIntValue(builder.getIntType(), i));
+
+ // Make a name hint: <valname>_<fieldname>
+ if (auto nameHint = val->findDecoration<IRNameHintDecoration>())
+ {
+ if (auto fieldHint = field->getKey()->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_" << fieldHint->getName();
+ builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice());
+ }
+ }
+
auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
if (!convertedElement)
return nullptr;
@@ -414,16 +438,21 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
{
// Create a new struct type with translated fields.
List<IRType*> fieldTypes;
+ List<IRNameHintDecoration*> fieldNames;
for (auto field : as<IRStructType>(type)->getFields())
{
fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
+ fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>());
}
auto hostStructType = builder->createStructType();
// Add fields to the struct.
for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++)
{
- builder->createStructField(hostStructType, builder->createStructKey(), fieldTypes[i]);
+ auto structKey = builder->createStructKey();
+ if (fieldNames[i])
+ builder->addNameHintDecoration(structKey, fieldNames[i]->getName());
+ builder->createStructField(hostStructType, structKey, fieldTypes[i]);
}
return hostStructType;
@@ -444,6 +473,43 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
return nullptr;
}
+// Propagates name hints through field extracts.
+IRInst* propagateNameHint(IRBuilder* builder, IRFieldExtract* inst)
+{
+ // If the field has a name hint, propagate it to the inst by appending the field name to the inst name
+ // (which must be fetched from the inst's name hint decoration).
+ //
+ // This is useful for propagating the name hint from a struct field to the inst that extracts the field.
+ if (auto nameHint = inst->getField()->findDecoration<IRNameHintDecoration>())
+ {
+ if (auto instNameHint = inst->getBase()->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << instNameHint->getName() << "_" << nameHint->getName();
+ builder->addNameHintDecoration(inst, newName.getUnownedSlice());
+ }
+ }
+
+ return inst;
+}
+
+// Propagates name hints through array indexing
+IRInst* propagateNameHint(IRBuilder* builder, IRGetElement* inst)
+{
+ // If the index is a constant, we can propagate the name hint from the inst to the index.
+ if (auto intLit = as<IRIntLit>(inst->getIndex()))
+ {
+ if (auto nameHint = inst->getBase()->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_" << intLit->getValue();
+ builder->addNameHintDecoration(inst, newName.getUnownedSlice());
+ }
+ }
+
+ return inst;
+}
+
IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst)
{
if (hostType == cudaType)
@@ -479,7 +545,9 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
builder,
hostFieldType,
cudaFieldType,
- builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()));
+ propagateNameHint(
+ builder,
+ cast<IRFieldExtract>(builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()))));
SLANG_RELEASE_ASSERT(castedField);
resultFields.add(castedField);
@@ -501,7 +569,9 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
builder,
hostElementType,
cudaElementType,
- builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)));
+ propagateNameHint(
+ builder,
+ cast<IRGetElement>(builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)))));
SLANG_RELEASE_ASSERT(castedElement);
resultElements.add(castedElement);
@@ -647,7 +717,8 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno
}
auto hostParam = builder->emitParam(type);
- // Add a namehint to the param by appending the suffix "_host".
+
+ // Add a namehint to the param
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
{
builder->addNameHintDecoration(hostParam, nameHint->getName());
@@ -1219,7 +1290,13 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
List<IRInst*> params;
for (auto param : func->getFirstBlock()->getParams())
{
- params.add(builder.emitParam(param->getFullType()));
+ auto newParam = builder.emitParam(param->getFullType());
+
+ // Copy over the name hint.
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ builder.addNameHintDecoration(newParam, nameHint->getName());
+
+ params.add(newParam);
}
wrapperFunc->setFullType(func->getFullType());
@@ -1277,7 +1354,13 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
List<IRInst*> params;
for (auto param : func->getFirstBlock()->getParams())
{
- params.add(builder.emitParam(param->getFullType()));
+ auto newParam = builder.emitParam(param->getFullType());
+
+ // Copy over the name hint.
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ builder.addNameHintDecoration(newParam, nameHint->getName());
+
+ params.add(newParam);
}
wrapperFunc->setFullType(func->getFullType());