summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-torch-prelude.h7
-rw-r--r--source/slang/slang-emit-torch.cpp109
2 files changed, 64 insertions, 52 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 4844e9248..70c516a3a 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -68,11 +68,14 @@ struct CudaTaskMemoryAllocator
}
};
-TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name)
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
{
+ // Convert device and scalar types.
if (!val.device().is_cuda())
val = val.to(torch::kCUDA);
-
+ if (val.dtype() != targetScalarType)
+ val = val.to(targetScalarType);
+
TensorView res = {};
res.dimensionCount = val.dim();
res.strides = allocator->allocUIntArray(val.dim());
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 4511039e3..276164ed5 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -10,6 +10,61 @@
namespace Slang
{
+
+void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
+{
+ m_writer->emit("torch::");
+
+ // Get the element type of the tensor.
+ auto instType = as<IRTorchTensorType>(type)->getOperand(0);
+
+ // If instType is a vector type, then we need to get the element type.
+ if (auto vectorType = as<IRVectorType>(instType))
+ {
+ instType = vectorType->getElementType();
+ }
+
+ switch (instType->getOp())
+ {
+ case kIROp_FloatType:
+ m_writer->emit("kFloat32");
+ break;
+ case kIROp_HalfType:
+ m_writer->emit("kFloat16");
+ break;
+ case kIROp_DoubleType:
+ m_writer->emit("kFloat64");
+ break;
+ case kIROp_UInt8Type:
+ m_writer->emit("kUInt8");
+ break;
+ case kIROp_UInt16Type:
+ m_writer->emit("kUInt16");
+ break;
+ case kIROp_UIntType:
+ m_writer->emit("kUInt32");
+ break;
+ case kIROp_UInt64Type:
+ m_writer->emit("kUInt64");
+ break;
+ case kIROp_Int8Type:
+ m_writer->emit("kInt8");
+ break;
+ case kIROp_Int16Type:
+ m_writer->emit("kInt16");
+ break;
+ case kIROp_IntType:
+ m_writer->emit("kInt32");
+ break;
+ case kIROp_Int64Type:
+ m_writer->emit("kInt64");
+ break;
+ default:
+ SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
+ break;
+ }
+}
+
bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
@@ -26,6 +81,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", ");
emitStringLiteral(getUnmangledName(inst->getOperand(1)));
+ m_writer->emit(", ");
+ emitTorchScalarTypeName(m_writer, inst->getOperand(1)->getDataType());
m_writer->emit(")");
return true;
}
@@ -98,56 +155,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
auto arg = inst->getOperand(i);
emitOperand(arg, getInfo(EmitOp::General));
}
- m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
-
- // Get the element type of the tensor.
- auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0);
-
- // If instType is a vector type, then we need to get the element type.
- if (auto vectorType = as<IRVectorType>(instType))
- {
- instType = vectorType->getElementType();
- }
-
- switch (instType->getOp())
- {
- case kIROp_FloatType:
- m_writer->emit("kFloat32");
- break;
- case kIROp_HalfType:
- m_writer->emit("kFloat16");
- break;
- case kIROp_DoubleType:
- m_writer->emit("kFloat64");
- break;
- case kIROp_UInt8Type:
- m_writer->emit("kUInt8");
- break;
- case kIROp_UInt16Type:
- m_writer->emit("kUInt16");
- break;
- case kIROp_UIntType:
- m_writer->emit("kUInt32");
- break;
- case kIROp_UInt64Type:
- m_writer->emit("kUInt64");
- break;
- case kIROp_Int8Type:
- m_writer->emit("kInt8");
- break;
- case kIROp_Int16Type:
- m_writer->emit("kInt16");
- break;
- case kIROp_IntType:
- m_writer->emit("kInt32");
- break;
- case kIROp_Int64Type:
- m_writer->emit("kInt64");
- break;
- default:
- SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
- break;
- }
+ m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(");
+ emitTorchScalarTypeName(m_writer, inst->getDataType());
m_writer->emit("))");
}
return true;