summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp97
1 files changed, 90 insertions, 7 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 6922984d6..2c91dc3b6 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -262,6 +262,15 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i));
+
+ // Make a name hint: <valname>_<i>
+ if (auto nameHint = val->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_" << i;
+ builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice());
+ }
+
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
@@ -275,7 +284,22 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
IRIntegerValue i = 0;
for (auto field : structType->getFields())
{
- auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(
+ translateToTupleType(builder, field->getFieldType()),
+ val,
+ builder.getIntValue(builder.getIntType(), i));
+
+ // Make a name hint: <valname>_<fieldname>
+ if (auto nameHint = val->findDecoration<IRNameHintDecoration>())
+ {
+ if (auto fieldHint = field->getKey()->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_" << fieldHint->getName();
+ builder.addNameHintDecoration(tupleElement, newName.getUnownedSlice());
+ }
+ }
+
auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
if (!convertedElement)
return nullptr;
@@ -414,16 +438,21 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
{
// Create a new struct type with translated fields.
List<IRType*> fieldTypes;
+ List<IRNameHintDecoration*> fieldNames;
for (auto field : as<IRStructType>(type)->getFields())
{
fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
+ fieldNames.add(field->getKey()->findDecoration<IRNameHintDecoration>());
}
auto hostStructType = builder->createStructType();
// Add fields to the struct.
for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++)
{
- builder->createStructField(hostStructType, builder->createStructKey(), fieldTypes[i]);
+ auto structKey = builder->createStructKey();
+ if (fieldNames[i])
+ builder->addNameHintDecoration(structKey, fieldNames[i]->getName());
+ builder->createStructField(hostStructType, structKey, fieldTypes[i]);
}
return hostStructType;
@@ -444,6 +473,43 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
return nullptr;
}
+// Propagates name hints through field extracts.
+IRInst* propagateNameHint(IRBuilder* builder, IRFieldExtract* inst)
+{
+ // If the field has a name hint, propagate it to the inst by appending the field name to the inst name
+ // (which must be fetched from the inst's name hint decoration).
+ //
+ // This is useful for propagating the name hint from a struct field to the inst that extracts the field.
+ if (auto nameHint = inst->getField()->findDecoration<IRNameHintDecoration>())
+ {
+ if (auto instNameHint = inst->getBase()->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << instNameHint->getName() << "_" << nameHint->getName();
+ builder->addNameHintDecoration(inst, newName.getUnownedSlice());
+ }
+ }
+
+ return inst;
+}
+
+// Propagates name hints through array indexing
+IRInst* propagateNameHint(IRBuilder* builder, IRGetElement* inst)
+{
+ // If the index is a constant, we can propagate the name hint from the inst to the index.
+ if (auto intLit = as<IRIntLit>(inst->getIndex()))
+ {
+ if (auto nameHint = inst->getBase()->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_" << intLit->getValue();
+ builder->addNameHintDecoration(inst, newName.getUnownedSlice());
+ }
+ }
+
+ return inst;
+}
+
IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst)
{
if (hostType == cudaType)
@@ -479,7 +545,9 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
builder,
hostFieldType,
cudaFieldType,
- builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()));
+ propagateNameHint(
+ builder,
+ cast<IRFieldExtract>(builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()))));
SLANG_RELEASE_ASSERT(castedField);
resultFields.add(castedField);
@@ -501,7 +569,9 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
builder,
hostElementType,
cudaElementType,
- builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)));
+ propagateNameHint(
+ builder,
+ cast<IRGetElement>(builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)))));
SLANG_RELEASE_ASSERT(castedElement);
resultElements.add(castedElement);
@@ -647,7 +717,8 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno
}
auto hostParam = builder->emitParam(type);
- // Add a namehint to the param by appending the suffix "_host".
+
+ // Add a namehint to the param
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
{
builder->addNameHintDecoration(hostParam, nameHint->getName());
@@ -1219,7 +1290,13 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
List<IRInst*> params;
for (auto param : func->getFirstBlock()->getParams())
{
- params.add(builder.emitParam(param->getFullType()));
+ auto newParam = builder.emitParam(param->getFullType());
+
+ // Copy over the name hint.
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ builder.addNameHintDecoration(newParam, nameHint->getName());
+
+ params.add(newParam);
}
wrapperFunc->setFullType(func->getFullType());
@@ -1277,7 +1354,13 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
List<IRInst*> params;
for (auto param : func->getFirstBlock()->getParams())
{
- params.add(builder.emitParam(param->getFullType()));
+ auto newParam = builder.emitParam(param->getFullType());
+
+ // Copy over the name hint.
+ if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+ builder.addNameHintDecoration(newParam, nameHint->getName());
+
+ params.add(newParam);
}
wrapperFunc->setFullType(func->getFullType());