summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-28 15:19:03 -0700
committerGitHub <noreply@github.com>2023-03-28 15:19:03 -0700
commita61f089fbc4b944d058e6417d8a0d22d57ca5c92 (patch)
tree4fa1a0c6370b8d34262d297653239f48aa004c71 /source
parent8f03af5e5b580170fab3fd2fe6144f92038c7701 (diff)
Add slangpy doc, fix cuda prelude. (#2748)
* Add slangpy doc, fix cuda prelude. * more bug fix. * fix. * fix. * More fix. * fix. * f * fix prelude. * update prelude. * update doc * Update prelude. * add zeros_like * update doc. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang44
-rw-r--r--source/slang/slang-emit-c-like.cpp7
-rw-r--r--source/slang/slang-emit-c-like.h1
-rw-r--r--source/slang/slang-emit-torch.cpp141
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp7
5 files changed, 136 insertions, 64 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index d5dc7842a..51cf1cdb7 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -76,6 +76,47 @@ struct TensorView
__target_intrinsic(cuda, "$0.strides[$1]")
[__readNone]
uint stride(uint i);
+
+ __subscript(uint index) -> T
+ {
+ [ForceInline] [__readNone] get { return load(index); }
+ [ForceInline] set { store(index, newValue); }
+ }
+ __subscript(uint i1, uint i2) -> T
+ {
+ [ForceInline] [__readNone] get { return load(i1, i2); }
+ [ForceInline] set { store(i1, i2, newValue); }
+ }
+ __subscript(uint2 i) -> T
+ {
+ [ForceInline] [__readNone] get { return load(i.x, i.y); }
+ [ForceInline] set { store(i.x, i.y, newValue); }
+ }
+ __subscript(uint i1, uint i2, uint i3) -> T
+ {
+ [ForceInline] [__readNone] get { return load(i1, i2, i3); }
+ [ForceInline] set { store(i1, i2, i3, newValue); }
+ }
+ __subscript(uint3 i) -> T
+ {
+ [ForceInline] [__readNone] get { return load(i.x, i.y, i.z); }
+ [ForceInline] set { store(i.x, i.y, i.z, newValue); }
+ }
+ __subscript(uint i1, uint i2, uint i3, uint i4) -> T
+ {
+ [ForceInline] [__readNone] get { return load(i1, i2, i3, i4); }
+ [ForceInline] set { store(i1, i2, i3, i4, newValue); }
+ }
+ __subscript(uint4 i) -> T
+ {
+ [__readNone][ForceInline] get { return load(i.x, i.y, i.z, i.w); }
+ [ForceInline] set { store(i.x, i.y, i.z, i.w, newValue); }
+ }
+ __subscript(uint i1, uint i2, uint i3, uint i4, uint i5) -> T
+ {
+ [ForceInline] [__readNone] get { return load(i1, i2, i3, i4, i5); }
+ [ForceInline] set { store(i1, i2, i3, i4, i5, newValue); }
+ }
}
__generic<T>
@@ -119,6 +160,9 @@ struct TorchTensor
__intrinsic_op($(kIROp_AllocateTorchTensor))
static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> zerosLike(TorchTensor<T> other);
}
__generic<T: IDifferentiable>
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 2bc739142..08ba050db 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -945,6 +945,13 @@ String CLikeSourceEmitter::getName(IRInst* inst)
return name;
}
+String CLikeSourceEmitter::getUnmangledName(IRInst* inst)
+{
+ if (auto nameHintDecor = inst->findDecoration<IRNameHintDecoration>())
+ return nameHintDecor->getName();
+ return getName(inst);
+}
+
void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
{
switch(inst->getOp())
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 9426531a8..8046fa633 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -324,6 +324,7 @@ public:
virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor);
String getName(IRInst* inst);
+ String getUnmangledName(IRInst* inst);
void emitSimpleValue(IRInst* inst) { emitSimpleValueImpl(inst); }
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 877c1dc03..4511039e3 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -24,6 +24,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitStringLiteral(getUnmangledName(inst->getOperand(1)));
m_writer->emit(")");
return true;
}
@@ -67,72 +69,87 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
}
case kIROp_AllocateTorchTensor:
{
- /*
- Emit something like:
- ```
- torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... },
- torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
- ```
- */
- m_writer->emit("torch::empty({ ");
- for (UInt i = 0; i < inst->getOperandCount(); i++)
+ if (as<IRTorchTensorType>(inst->getOperand(0)->getDataType()))
{
- if (i > 0)
- m_writer->emit(", ");
- auto arg = inst->getOperand(i);
- emitOperand(arg, getInfo(EmitOp::General));
+ /*
+ Emit something like:
+ ```
+ torch::Tensor out = torch::zeros_like(other);
+ ```
+ */
+ m_writer->emit("torch::zeros_like(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
}
- m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
-
- // Get the element type of the tensor.
- auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0);
-
- // If instType is a vector type, then we need to get the element type.
- if (auto vectorType = as<IRVectorType>(instType))
+ else
{
- instType = vectorType->getElementType();
+ /*
+ Emit something like:
+ ```
+ torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... },
+ torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
+ ```
+ */
+ m_writer->emit("torch::empty({ ");
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ auto arg = inst->getOperand(i);
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
+
+ // Get the element type of the tensor.
+ auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0);
+
+ // If instType is a vector type, then we need to get the element type.
+ if (auto vectorType = as<IRVectorType>(instType))
+ {
+ instType = vectorType->getElementType();
+ }
+
+ switch (instType->getOp())
+ {
+ case kIROp_FloatType:
+ m_writer->emit("kFloat32");
+ break;
+ case kIROp_HalfType:
+ m_writer->emit("kFloat16");
+ break;
+ case kIROp_DoubleType:
+ m_writer->emit("kFloat64");
+ break;
+ case kIROp_UInt8Type:
+ m_writer->emit("kUInt8");
+ break;
+ case kIROp_UInt16Type:
+ m_writer->emit("kUInt16");
+ break;
+ case kIROp_UIntType:
+ m_writer->emit("kUInt32");
+ break;
+ case kIROp_UInt64Type:
+ m_writer->emit("kUInt64");
+ break;
+ case kIROp_Int8Type:
+ m_writer->emit("kInt8");
+ break;
+ case kIROp_Int16Type:
+ m_writer->emit("kInt16");
+ break;
+ case kIROp_IntType:
+ m_writer->emit("kInt32");
+ break;
+ case kIROp_Int64Type:
+ m_writer->emit("kInt64");
+ break;
+ default:
+ SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
+ break;
+ }
+ m_writer->emit("))");
}
-
- switch (instType->getOp())
- {
- case kIROp_FloatType:
- m_writer->emit("kFloat32");
- break;
- case kIROp_HalfType:
- m_writer->emit("kFloat16");
- break;
- case kIROp_DoubleType:
- m_writer->emit("kFloat64");
- break;
- case kIROp_UInt8Type:
- m_writer->emit("kUInt8");
- break;
- case kIROp_UInt16Type:
- m_writer->emit("kUInt16");
- break;
- case kIROp_UIntType:
- m_writer->emit("kUInt32");
- break;
- case kIROp_UInt64Type:
- m_writer->emit("kUInt64");
- break;
- case kIROp_Int8Type:
- m_writer->emit("kInt8");
- break;
- case kIROp_Int16Type:
- m_writer->emit("kInt16");
- break;
- case kIROp_IntType:
- m_writer->emit("kInt32");
- break;
- case kIROp_Int64Type:
- m_writer->emit("kInt64");
- break;
- default:
- SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
- break;
- }
- m_writer->emit("))");
return true;
}
}
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index eb81bfd8c..971e87a6f 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -245,6 +245,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
return;
}
auto newParam = builder.emitParam(newParamType);
+ param->transferDecorationsTo(newParam);
newParams.add(newParam);
}
@@ -361,14 +362,16 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
// Remove all [TorchEntryPoint] functions when emitting CUDA source.
void removeTorchKernels(IRModule* module)
{
+ List<IRInst*> toRemove;
for (auto globalInst : module->getGlobalInsts())
{
if (!as<IRFunc>(globalInst))
continue;
if (globalInst->findDecoration<IRTorchEntryPointDecoration>())
- globalInst->removeAndDeallocate();
+ toRemove.add(globalInst);
}
-
+ for (auto inst : toRemove)
+ inst->removeAndDeallocate();
}
}