diff options
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 4 |
4 files changed, 38 insertions, 3 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 1bbd42168..0e0349bd7 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -382,6 +382,11 @@ SLANG_MAKE_VECTOR(ulonglong) SLANG_MAKE_VECTOR(__half) #endif +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool1 make_bool1(bool x) { return bool1{ x }; } +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool2 make_bool2(bool x, bool y) { return bool2{ x, y }; } +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool3 make_bool3(bool x, bool y, bool z) { return bool3{ x, y, z }; } +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool4 make_bool4(bool x, bool y, bool z, bool w) { return bool4{ x, y, z, w }; } + #if SLANG_CUDA_RTC #define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\ @@ -408,6 +413,27 @@ SLANG_MAKE_VECTOR_FROM_SCALAR(double) SLANG_MAKE_VECTOR_FROM_SCALAR(__half) #endif +#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn,T,N) \ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T##N Fn(T##N* address, T##N val) \ + {\ + T##N result; \ + for (int i = 0; i < N; i++) \ + *_slang_vector_get_element_ptr(&result, i) = Fn(_slang_vector_get_element_ptr(address, i), _slang_vector_get_element(val, i)); \ + return result; \ + }\ + +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 4) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 4) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 4) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 4) template<typename T, int n> struct GetVectorTypeImpl {}; diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index c198b011d..819a6a136 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -165,6 +165,15 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& auto arg = inst->getOperand(i); emitOperand(arg, getInfo(EmitOp::General)); } + if (as<IRTorchTensorType>(inst->getDataType())) + { + if (auto vectorType = as<IRVectorType>(inst->getDataType()->getOperand(0))) + { + // If the element type of the tensor is a vector, we need to add the vector size to the shape. + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + } + } m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); emitTorchScalarTypeName(m_writer, inst->getDataType()); m_writer->emit("))"); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 26d00d4df..a19896287 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2807,7 +2807,7 @@ public: IRArrayListType* getArrayListType(IRType* elementType); IRTensorViewType* getTensorViewType(IRType* elementType); - IRTorchTensorType* getTorchTensorType(); + IRTorchTensorType* getTorchTensorType(IRType* elementType); IRDifferentialPairType* getDifferentialPairType( IRType* valueType, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6ca05d2d6..bd2271953 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2837,9 +2837,9 @@ namespace Slang (IRInst**)&elementType); } - IRTorchTensorType* IRBuilder::getTorchTensorType() + IRTorchTensorType* IRBuilder::getTorchTensorType(IRType* elementType) { - return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr); + return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 1, (IRInst**)&elementType); } IRDifferentialPairType* IRBuilder::getDifferentialPairType( |
