summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- prelude/slang-cuda-prelude.h 26
-rw-r--r-- source/slang/slang-emit-torch.cpp 9
-rw-r--r-- source/slang/slang-ir-insts.h 2
-rw-r--r-- source/slang/slang-ir.cpp 4
4 files changed, 38 insertions, 3 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 1bbd42168..0e0349bd7 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -382,6 +382,11 @@ SLANG_MAKE_VECTOR(ulonglong)
SLANG_MAKE_VECTOR(__half)
#endif
+SLANG_FORCE_INLINE SLANG_CUDA_CALL bool1 make_bool1(bool x) { return bool1{ x }; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL bool2 make_bool2(bool x, bool y) { return bool2{ x, y }; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL bool3 make_bool3(bool x, bool y, bool z) { return bool3{ x, y, z }; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL bool4 make_bool4(bool x, bool y, bool z, bool w) { return bool4{ x, y, z, w }; }
+
#if SLANG_CUDA_RTC
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\
@@ -408,6 +413,27 @@ SLANG_MAKE_VECTOR_FROM_SCALAR(double)
SLANG_MAKE_VECTOR_FROM_SCALAR(__half)
#endif
+#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn,T,N) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##N Fn(T##N* address, T##N val) \
+ {\
+ T##N result; \
+ for (int i = 0; i < N; i++) \
+ *_slang_vector_get_element_ptr(&result, i) = Fn(_slang_vector_get_element_ptr(address, i), _slang_vector_get_element(val, i)); \
+ return result; \
+ }\
+
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 2)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 3)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 4)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 2)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 3)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 4)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 2)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 3)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 4)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 2)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 3)
+SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 4)
template<typename T, int n>
struct GetVectorTypeImpl {};
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index c198b011d..819a6a136 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -165,6 +165,15 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
auto arg = inst->getOperand(i);
emitOperand(arg, getInfo(EmitOp::General));
}
+ if (as<IRTorchTensorType>(inst->getDataType()))
+ {
+ if (auto vectorType = as<IRVectorType>(inst->getDataType()->getOperand(0)))
+ {
+ // If the element type of the tensor is a vector, we need to add the vector size to the shape.
+ m_writer->emit(", ");
+ emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General));
+ }
+ }
m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(");
emitTorchScalarTypeName(m_writer, inst->getDataType());
m_writer->emit("))");
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 26d00d4df..a19896287 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2807,7 +2807,7 @@ public:
IRArrayListType* getArrayListType(IRType* elementType);
IRTensorViewType* getTensorViewType(IRType* elementType);
- IRTorchTensorType* getTorchTensorType();
+ IRTorchTensorType* getTorchTensorType(IRType* elementType);
IRDifferentialPairType* getDifferentialPairType(
IRType* valueType,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6ca05d2d6..bd2271953 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2837,9 +2837,9 @@ namespace Slang
(IRInst**)&elementType);
}
- IRTorchTensorType* IRBuilder::getTorchTensorType()
+ IRTorchTensorType* IRBuilder::getTorchTensorType(IRType* elementType)
{
- return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr);
+ return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 1, (IRInst**)&elementType);
}
IRDifferentialPairType* IRBuilder::getDifferentialPairType(