summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-torch.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-emit-torch.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
-rw-r--r--source/slang/slang-emit-torch.cpp204
1 files changed, 94 insertions, 110 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index bacc9d030..2da2cf79c 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -2,7 +2,6 @@
#include "slang-emit-torch.h"
#include "../core/slang-writer.h"
-
#include "slang-emit-source-writer.h"
#include "slang-mangled-lexer.h"
@@ -26,37 +25,19 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
switch (instType->getOp())
{
- case kIROp_FloatType:
- m_writer->emit("kFloat32");
- break;
- case kIROp_HalfType:
- m_writer->emit("kFloat16");
- break;
- case kIROp_DoubleType:
- m_writer->emit("kFloat64");
- break;
- case kIROp_UInt8Type:
- m_writer->emit("kUInt8");
- break;
- case kIROp_Int8Type:
- m_writer->emit("kInt8");
- break;
- case kIROp_Int16Type:
- m_writer->emit("kInt16");
- break;
- case kIROp_IntType:
- m_writer->emit("kInt32");
- break;
- case kIROp_Int64Type:
- m_writer->emit("kInt64");
- break;
- case kIROp_BoolType:
- m_writer->emit("kBool");
- break;
+ case kIROp_FloatType: m_writer->emit("kFloat32"); break;
+ case kIROp_HalfType: m_writer->emit("kFloat16"); break;
+ case kIROp_DoubleType: m_writer->emit("kFloat64"); break;
+ case kIROp_UInt8Type: m_writer->emit("kUInt8"); break;
+ case kIROp_Int8Type: m_writer->emit("kInt8"); break;
+ case kIROp_Int16Type: m_writer->emit("kInt16"); break;
+ case kIROp_IntType: m_writer->emit("kInt32"); break;
+ case kIROp_Int64Type: m_writer->emit("kInt64"); break;
+ case kIROp_BoolType: m_writer->emit("kBool"); break;
default:
- SLANG_UNEXPECTED((
- std::string("unknown scalar type in allocTorchTensor: ") +
- std::string(getIROpInfo(type->getOp()).name)).c_str());
+ SLANG_UNEXPECTED((std::string("unknown scalar type in allocTorchTensor: ") +
+ std::string(getIROpInfo(type->getOp()).name))
+ .c_str());
break;
}
}
@@ -65,8 +46,7 @@ bool TorchCppSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
{
switch (inst->getOp())
{
- default:
- return false;
+ default: return false;
case kIROp_CudaKernelLaunch:
{
m_writer->emit("AT_CUDA_CHECK(cudaLaunchKernel(");
@@ -96,7 +76,7 @@ bool TorchCppSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
m_writer->emit("((cudaStream_t)");
emitOperand(inst->getOperand(4), getInfo(EmitOp::General));
m_writer->emit(")));\n");
-
+
return true;
}
}
@@ -107,100 +87,103 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
switch (inst->getOp())
{
default:
- {
- return Super::tryEmitInstExprImpl(inst, inOuterPrec);
- }
+ {
+ return Super::tryEmitInstExprImpl(inst, inOuterPrec);
+ }
case kIROp_MakeTensorView:
- {
- m_writer->emit("make_tensor_view(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitStringLiteral(getUnmangledName(inst->getOperand(0)));
- m_writer->emit(", ");
- emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType());
- m_writer->emit(", ");
+ {
+ m_writer->emit("make_tensor_view(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitStringLiteral(getUnmangledName(inst->getOperand(0)));
+ m_writer->emit(", ");
+ emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType());
+ m_writer->emit(", ");
- auto tensorViewType = as<IRTensorViewType>(inst->getDataType());
- if (as<IRVectorType>(tensorViewType->getElementType()))
- m_writer->emit("true");
- else
- m_writer->emit("false");
-
- m_writer->emit(")");
- return true;
- }
+ auto tensorViewType = as<IRTensorViewType>(inst->getDataType());
+ if (as<IRVectorType>(tensorViewType->getElementType()))
+ m_writer->emit("true");
+ else
+ m_writer->emit("false");
+
+ m_writer->emit(")");
+ return true;
+ }
case kIROp_TorchGetCudaStream:
- {
- m_writer->emit("at::cuda::getCurrentCUDAStream()");
- return true;
- }
- case kIROp_AllocateTorchTensor:
- {
- if (as<IRTorchTensorType>(inst->getOperand(0)->getDataType()))
{
- /*
- Emit something like:
- ```
- torch::Tensor out = torch::empty_like(other);
- ```
- */
- m_writer->emit("torch::empty_like(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype(");
- emitTorchScalarTypeName(m_writer, inst->getDataType());
- m_writer->emit("))");
+ m_writer->emit("at::cuda::getCurrentCUDAStream()");
+ return true;
}
- else
+ case kIROp_AllocateTorchTensor:
{
- /*
- Emit something like:
- ```
- torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... },
- torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
- ```
- */
- m_writer->emit("torch::empty({ ");
- for (UInt i = 0; i < inst->getOperandCount(); i++)
+ if (as<IRTorchTensorType>(inst->getOperand(0)->getDataType()))
{
- if (i > 0)
- m_writer->emit(", ");
- auto arg = inst->getOperand(i);
- emitOperand(arg, getInfo(EmitOp::General));
+ /*
+ Emit something like:
+ ```
+ torch::Tensor out = torch::empty_like(other);
+ ```
+ */
+ m_writer->emit("torch::empty_like(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype(");
+ emitTorchScalarTypeName(m_writer, inst->getDataType());
+ m_writer->emit("))");
}
- if (as<IRTorchTensorType>(inst->getDataType()))
+ else
{
- if (auto vectorType = as<IRVectorType>(inst->getDataType()->getOperand(0)))
+ /*
+ Emit something like:
+ ```
+ torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... },
+ torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
+ ```
+ */
+ m_writer->emit("torch::empty({ ");
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ auto arg = inst->getOperand(i);
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ if (as<IRTorchTensorType>(inst->getDataType()))
{
- // If the element type of the tensor is a vector, we need to add the vector size to the shape.
- m_writer->emit(", ");
- emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General));
+ if (auto vectorType = as<IRVectorType>(inst->getDataType()->getOperand(0)))
+ {
+ // If the element type of the tensor is a vector, we need to add the vector
+ // size to the shape.
+ m_writer->emit(", ");
+ emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General));
+ }
}
+ m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(");
+ emitTorchScalarTypeName(m_writer, inst->getDataType());
+ m_writer->emit("))");
}
- m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(");
- emitTorchScalarTypeName(m_writer, inst->getDataType());
- m_writer->emit("))");
+ return true;
}
- return true;
- }
}
}
-SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out)
+SlangResult TorchCppSourceEmitter::calcTypeName(
+ IRType* type,
+ CodeGenTarget target,
+ StringBuilder& out)
{
switch (type->getOp())
{
- default:
- return Super::calcTypeName(type, target, out);
+ default: return Super::calcTypeName(type, target, out);
case kIROp_TensorViewType:
- {
- out << "TensorView";
- return SLANG_OK;
- }
+ {
+ out << "TensorView";
+ return SLANG_OK;
+ }
case kIROp_TorchTensorType:
- {
- out << "torch::Tensor";
- return SLANG_OK;
- }
+ {
+ out << "torch::Tensor";
+ return SLANG_OK;
+ }
}
}
@@ -214,9 +197,11 @@ void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sin
for (auto inst : module->getGlobalInsts())
{
auto func = as<IRFunc>(inst);
- if (!func) continue;
+ if (!func)
+ continue;
auto decor = func->findDecoration<IRTorchEntryPointDecoration>();
- if (!decor) continue;
+ if (!decor)
+ continue;
m_writer->emit("m.def(");
emitStringLiteral(decor->getFunctionName());
m_writer->emit(", &");
@@ -227,7 +212,6 @@ void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sin
}
m_writer->dedent();
m_writer->emit("}\n");
-
}
} // namespace Slang