diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-28 15:19:03 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-28 15:19:03 -0700 |
| commit | a61f089fbc4b944d058e6417d8a0d22d57ca5c92 (patch) | |
| tree | 4fa1a0c6370b8d34262d297653239f48aa004c71 /source | |
| parent | 8f03af5e5b580170fab3fd2fe6144f92038c7701 (diff) | |
Add slangpy doc, fix cuda prelude. (#2748)
* Add slangpy doc, fix cuda prelude.
* more bug fix.
* fix.
* fix.
* More fix.
* fix.
* f
* fix prelude.
* update prelude.
* update doc
* Update prelude.
* add zeros_like
* update doc.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 44 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 141 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 7 |
5 files changed, 136 insertions, 64 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index d5dc7842a..51cf1cdb7 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -76,6 +76,47 @@ struct TensorView __target_intrinsic(cuda, "$0.strides[$1]") [__readNone] uint stride(uint i); + + __subscript(uint index) -> T + { + [ForceInline] [__readNone] get { return load(index); } + [ForceInline] set { store(index, newValue); } + } + __subscript(uint i1, uint i2) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2); } + [ForceInline] set { store(i1, i2, newValue); } + } + __subscript(uint2 i) -> T + { + [ForceInline] [__readNone] get { return load(i.x, i.y); } + [ForceInline] set { store(i.x, i.y, newValue); } + } + __subscript(uint i1, uint i2, uint i3) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2, i3); } + [ForceInline] set { store(i1, i2, i3, newValue); } + } + __subscript(uint3 i) -> T + { + [ForceInline] [__readNone] get { return load(i.x, i.y, i.z); } + [ForceInline] set { store(i.x, i.y, i.z, newValue); } + } + __subscript(uint i1, uint i2, uint i3, uint i4) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2, i3, i4); } + [ForceInline] set { store(i1, i2, i3, i4, newValue); } + } + __subscript(uint4 i) -> T + { + [__readNone][ForceInline] get { return load(i.x, i.y, i.z, i.w); } + [ForceInline] set { store(i.x, i.y, i.z, i.w, newValue); } + } + __subscript(uint i1, uint i2, uint i3, uint i4, uint i5) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2, i3, i4, i5); } + [ForceInline] set { store(i1, i2, i3, i4, i5, newValue); } + } } __generic<T> @@ -119,6 +160,9 @@ struct TorchTensor __intrinsic_op($(kIROp_AllocateTorchTensor)) static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> zerosLike(TorchTensor<T> other); } __generic<T: IDifferentiable> diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 2bc739142..08ba050db 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -945,6 +945,13 @@ String CLikeSourceEmitter::getName(IRInst* inst) return name; } +String CLikeSourceEmitter::getUnmangledName(IRInst* inst) +{ + if (auto nameHintDecor = inst->findDecoration<IRNameHintDecoration>()) + return nameHintDecor->getName(); + return getName(inst); +} + void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst) { switch(inst->getOp()) diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 9426531a8..8046fa633 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -324,6 +324,7 @@ public: virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor); String getName(IRInst* inst); + String getUnmangledName(IRInst* inst); void emitSimpleValue(IRInst* inst) { emitSimpleValueImpl(inst); } diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index 877c1dc03..4511039e3 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -24,6 +24,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitStringLiteral(getUnmangledName(inst->getOperand(1))); m_writer->emit(")"); return true; } @@ -67,72 +69,87 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& } case kIROp_AllocateTorchTensor: { - /* - Emit something like: - ``` - torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, - torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); - ``` - */ - m_writer->emit("torch::empty({ "); - for (UInt i = 0; i < inst->getOperandCount(); i++) + if (as<IRTorchTensorType>(inst->getOperand(0)->getDataType())) { - if (i > 0) - m_writer->emit(", "); - auto arg = inst->getOperand(i); - emitOperand(arg, getInfo(EmitOp::General)); + /* + Emit something like: + ``` + torch::Tensor out = torch::zeros_like(other); + ``` + */ + m_writer->emit("torch::zeros_like("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); } - m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); - - // Get the element type of the tensor. - auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); - - // If instType is a vector type, then we need to get the element type. - if (auto vectorType = as<IRVectorType>(instType)) + else { - instType = vectorType->getElementType(); + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); + + // Get the element type of the tensor. + auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); + + // If instType is a vector type, then we need to get the element type. + if (auto vectorType = as<IRVectorType>(instType)) + { + instType = vectorType->getElementType(); + } + + switch (instType->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } + m_writer->emit("))"); } - - switch (instType->getOp()) - { - case kIROp_FloatType: - m_writer->emit("kFloat32"); - break; - case kIROp_HalfType: - m_writer->emit("kFloat16"); - break; - case kIROp_DoubleType: - m_writer->emit("kFloat64"); - break; - case kIROp_UInt8Type: - m_writer->emit("kUInt8"); - break; - case kIROp_UInt16Type: - m_writer->emit("kUInt16"); - break; - case kIROp_UIntType: - m_writer->emit("kUInt32"); - break; - case kIROp_UInt64Type: - m_writer->emit("kUInt64"); - break; - case kIROp_Int8Type: - m_writer->emit("kInt8"); - break; - case kIROp_Int16Type: - m_writer->emit("kInt16"); - break; - case kIROp_IntType: - m_writer->emit("kInt32"); - break; - case kIROp_Int64Type: - m_writer->emit("kInt64"); - break; - default: - SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); - break; - } - m_writer->emit("))"); return true; } } diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index eb81bfd8c..971e87a6f 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -245,6 +245,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) return; } auto newParam = builder.emitParam(newParamType); + param->transferDecorationsTo(newParam); newParams.add(newParam); } @@ -361,14 +362,16 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) // Remove all [TorchEntryPoint] functions when emitting CUDA source. void removeTorchKernels(IRModule* module) { + List<IRInst*> toRemove; for (auto globalInst : module->getGlobalInsts()) { if (!as<IRFunc>(globalInst)) continue; if (globalInst->findDecoration<IRTorchEntryPointDecoration>()) - globalInst->removeAndDeallocate(); + toRemove.add(globalInst); } - + for (auto inst : toRemove) + inst->removeAndDeallocate(); } } |
