summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-10-05 12:52:49 -0400
committerGitHub <noreply@github.com>2023-10-05 09:52:49 -0700
commit441e13e13f30b96eb04c05725ad7fe1983c92f53 (patch)
treeaee5c31b62876ef8ad60a37b2a4767b6f1a299c6 /source/slang/slang-ir-pytorch-cpp-binding.cpp
parent65751ce222adb302e62b5b7b6312de65638abed5 (diff)
Various AD Fixes (#3263)
* Various fixes * Remove unused parameter * Update slang-ir-loop-unroll.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp110
1 files changed, 95 insertions, 15 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 41665ddf7..3a7e8b9fb 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -177,9 +177,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
}
List<IRInst*> resultElements;
auto elementType = arrayType->getElementType();
+ auto tupleElementType = translateToTupleType(builder, elementType);
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i));
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
@@ -346,7 +347,7 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
List<IRType*> fieldTypes;
for (auto field : as<IRStructType>(type)->getFields())
{
- fieldTypes.add(translateToHostType(builder, field->getFieldType(), func));
+ fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
}
auto hostStructType = builder->createStructType();
@@ -358,6 +359,13 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
return hostStructType;
}
+ case kIROp_ArrayType:
+ {
+ auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink);
+ if (!elementType)
+ return nullptr;
+ return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount());
+ }
default:
break;
}
@@ -422,13 +430,36 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer());
}
+ case kIROp_ArrayType:
+ {
+ auto cudaArrayType = cast<IRArrayType>(cudaType);
+ auto hostArrayType = cast<IRArrayType>(hostType);
+
+ List<IRInst*> resultElements;
+ for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++)
+ {
+ auto cudaElementType = cudaArrayType->getElementType();
+ auto hostElementType = hostArrayType->getElementType();
+ auto castedElement = castHostToCUDAType(
+ builder,
+ hostElementType,
+ cudaElementType,
+ builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)));
+
+ SLANG_RELEASE_ASSERT(castedElement);
+ resultElements.add(castedElement);
+ }
+
+ return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
default:
break;
}
- // If translateToHostType worked correctly, we shouldn't get here.
- SLANG_UNREACHABLE("unhandled type");
+ // If translateToHostType worked correctly, there should be no unhandled cases here.
+ // However, we won't diagnose here since its already diagnosed in translateToHostType()
+ return nullptr;
}
void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc)
@@ -553,6 +584,12 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno
auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink);
if (outType)
*outType = type;
+
+ if (!type || sink->getErrorCount() > 0)
+ {
+ return nullptr;
+ }
+
auto hostParam = builder->emitParam(type);
// Add a namehint to the param by appending the suffix "_host".
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
@@ -600,6 +637,38 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
}
return;
}
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ IRBuilder builder(arrayType->getModule());
+ if (!arrayType->findDecoration<IRPyExportDecoration>())
+ builder.addPyExportDecoration(arrayType, UnownedStringSlice("Array"));
+
+ markTypeForPyExport(arrayType->getElementType(), sink);
+ return;
+ }
+}
+
+String tryGetExportTypeName(IRBuilder* builder, IRType* type)
+{
+ if (auto structType = as<IRStructType>(type))
+ {
+ if (auto pyExportDecoration = type->findDecoration<IRPyExportDecoration>())
+ return String(pyExportDecoration->getExportName());
+ else
+ return String("");
+ }
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ StringBuilder nameBuilder;
+ nameBuilder << "Array_";
+ nameBuilder << tryGetExportTypeName(builder, arrayType->getElementType());
+ nameBuilder << "_";
+ nameBuilder << cast<IRIntLit>(arrayType->getElementCount())->getValue();
+
+ return nameBuilder.produceString();
+ }
+ else
+ return String();
}
void generateReflectionForType(IRType* type, DiagnosticSink* sink)
@@ -609,7 +678,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// The list will contain the names of all the fields of the type.
//
- // TODO: Fix this to avoid emitting the same type reflection multiple times.
if (!type->findDecoration<IRPyExportDecoration>())
return;
@@ -635,20 +703,32 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
else
fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
- if (!field->getFieldType()->findDecoration<IRPyExportDecoration>())
- {
- fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
- continue;
- }
-
auto fieldType = field->getFieldType();
+ auto exportName = tryGetExportTypeName(&builder, fieldType);
- fieldTypeNames.add(
- builder.emitGetNativeString(
- builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName())));
+ if (exportName.getLength() > 0)
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice())));
+ else
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
}
break;
}
+ case kIROp_ArrayType:
+ {
+ auto elementType = as<IRArrayType>(type)->getElementType();
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type"))));
+ fieldTypeNames.add(
+ builder.emitGetNativeString(
+ builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice())));
+
+ auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount());
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size"))));
+
+ StringBuilder elementCountStr;
+ elementCountStr << elementCount->getValue();
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice())));
+ break;
+ }
default:
break;
}
@@ -676,7 +756,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// Set function name.
StringBuilder reflFuncExportName;
- reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName();
+ reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice();
builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());