summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-09 18:00:48 -0700
committerGitHub <noreply@github.com>2023-05-09 18:00:48 -0700
commitddebd60853b3f34bfd8e89de804fd15808abf75d (patch)
treed5d686843bc2c67e493693376a0170857998c077
parent38ed03a7203baacf36fca62539ac74fd45ed42d2 (diff)
Various fixes for autodiff and slangpy. (#2876)
* Various fixes for autodiff and slangpy. * Fix cuda code gen for `select`. * Fix getBuildTagString(). * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--prelude/slang-cpp-host-prelude.h16
-rw-r--r--prelude/slang-cpp-types-core.h17
-rw-r--r--prelude/slang-cuda-prelude.h32
-rw-r--r--prelude/slang-torch-prelude.h17
-rw-r--r--source/slang/diff.meta.slang18
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-emit-c-like.cpp1
-rw-r--r--source/slang/slang-emit-cpp.cpp11
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp46
-rw-r--r--source/slang/slang-ir-util.cpp28
-rw-r--r--source/slang/slang.cpp8
11 files changed, 179 insertions, 18 deletions
diff --git a/prelude/slang-cpp-host-prelude.h b/prelude/slang-cpp-host-prelude.h
index 0a026471c..f69d03eed 100644
--- a/prelude/slang-cpp-host-prelude.h
+++ b/prelude/slang-cpp-host-prelude.h
@@ -28,6 +28,22 @@
# include <stdint.h>
#endif // SLANG_LLVM
+// Export annotation for symbols that must be visible from the compiled shared library:
+// `__declspec(dllexport)` when building with MSVC, the default-visibility attribute otherwise.
+// NOTE(review): the test is on _MSC_VER only, so other compilers targeting Windows take the
+// visibility-attribute branch -- confirm that is intended.
+#if defined(_MSC_VER)
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
+#else
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
+//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
+#endif
+
+// C-linkage helpers. In C++ they expand to `extern "C"` (single declaration) or to an
+// `extern "C" { ... }` pair (START/END); when not compiling as C++ they expand to nothing.
+#ifdef __cplusplus
+# define SLANG_PRELUDE_EXTERN_C extern "C"
+# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
+# define SLANG_PRELUDE_EXTERN_C_END }
+#else
+# define SLANG_PRELUDE_EXTERN_C
+# define SLANG_PRELUDE_EXTERN_C_START
+# define SLANG_PRELUDE_EXTERN_C_END
+#endif
#include "slang-cpp-scalar-intrinsics.h"
diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h
index c49ee013c..25fe47202 100644
--- a/prelude/slang-cpp-types-core.h
+++ b/prelude/slang-cpp-types-core.h
@@ -203,6 +203,23 @@ struct Vector<T, 4>
};
+// Component-wise select for N-wide vectors: component k of the result is v0[k] when
+// condition[k] is true and v1[k] otherwise.
template<typename T, int N>
+SLANG_FORCE_INLINE Vector<T, N> _slang_select(Vector<bool, N> condition, Vector<T, N> v0, Vector<T, N> v1)
+{
+    Vector<T, N> picked;
+    for (int lane = 0; lane < N; ++lane)
+    {
+        if (condition[lane])
+            picked[lane] = v0[lane];
+        else
+            picked[lane] = v1[lane];
+    }
+    return picked;
+}
+
+// Scalar select: yields a copy of v0 when `condition` is true, otherwise a copy of v1.
+template<typename T>
+SLANG_FORCE_INLINE T _slang_select(bool condition, T v0, T v1)
+{
+    // Bind the chosen argument by reference first so the return still copies from an lvalue.
+    T& chosen = condition ? v0 : v1;
+    return chosen;
+}
+
+template<typename T, int N>
SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
{
return x[index];
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 0e0349bd7..0362a3a67 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -743,11 +743,43 @@ SLANG_FLOAT_MATRIX_OPS(__half)
#undef SLANG_MATRIX_INT_NEG_OP
#undef SLANG_FLOAT_MATRIX_MOD
+// Defines the `_slang_select` overload for `Vector<T, N>` whose condition is the `bool##N`
+// vector type (e.g. bool2 for N == 2). For each component index i in [0, N) the result takes
+// v0's component when the condition's component is true and v1's component otherwise.
+// Components are read through `_slang_vector_get_element` and written through
+// `_slang_vector_get_element_ptr`.
+#define SLANG_SELECT_IMPL(T, N)\
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, N> _slang_select(bool##N condition, Vector<T, N> v0, Vector<T, N> v1) \
+{ \
+ Vector<T, N> result; \
+ for (int i = 0; i < N; i++) \
+ { \
+ *_slang_vector_get_element_ptr(&result, i) = _slang_vector_get_element(condition, i) ? _slang_vector_get_element(v0, i) : _slang_vector_get_element(v1, i); \
+ } \
+ return result; \
+}
+// Expands SLANG_SELECT_IMPL for element type T at vector widths 2, 3 and 4.
+#define SLANG_SELECT_T(T)\
+ SLANG_SELECT_IMPL(T, 2)\
+ SLANG_SELECT_IMPL(T, 3)\
+ SLANG_SELECT_IMPL(T, 4)
+
+// Vector `_slang_select` overloads for the element types listed below.
+// NOTE(review): no 64-bit integer or bool element types appear in this list -- confirm
+// whether vector select on those types needs an overload as well.
+SLANG_SELECT_T(int)
+SLANG_SELECT_T(uint)
+SLANG_SELECT_T(short)
+SLANG_SELECT_T(ushort)
+SLANG_SELECT_T(char)
+SLANG_SELECT_T(uchar)
+SLANG_SELECT_T(float)
+SLANG_SELECT_T(double)
+
+// Scalar select: yields a copy of v0 when `condition` is true, otherwise a copy of v1.
+template<typename T>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_select(bool condition, T v0, T v1)
+{
+    // Bind the chosen argument by reference first so the return still copies from an lvalue.
+    T& chosen = condition ? v0 : v1;
+    return chosen;
+}
+
//
// Half support
//
#if SLANG_CUDA_ENABLE_HALF
+SLANG_SELECT_T(__half)
+
// Convenience functions ushort -> half
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __ushort_as_half(const ushort2& i) { return __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)); }
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 2f5273e1f..0a69b8305 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -38,6 +38,23 @@
# include <stdint.h>
#endif // SLANG_LLVM
+// Export annotation for symbols that must be visible from the compiled shared library:
+// `__declspec(dllexport)` when building with MSVC, the default-visibility attribute otherwise.
+// NOTE(review): the test is on _MSC_VER only, so other compilers targeting Windows take the
+// visibility-attribute branch -- confirm that is intended.
+#if defined(_MSC_VER)
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
+#else
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
+//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
+#endif
+
+// C-linkage helpers. In C++ they expand to `extern "C"` (single declaration) or to an
+// `extern "C" { ... }` pair (START/END); when not compiling as C++ they expand to nothing.
+#ifdef __cplusplus
+# define SLANG_PRELUDE_EXTERN_C extern "C"
+# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
+# define SLANG_PRELUDE_EXTERN_C_END }
+#else
+# define SLANG_PRELUDE_EXTERN_C
+# define SLANG_PRELUDE_EXTERN_C_START
+# define SLANG_PRELUDE_EXTERN_C_END
+#endif
+
#define SLANG_PRELUDE_NAMESPACE
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 0725103da..f2fd8e3b0 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -246,52 +246,66 @@ __intrinsic_type($(kIROp_TorchTensorType))
struct TorchTensor
{
__intrinsic_op($(kIROp_TorchTensorGetView))
+ [CudaHost]
TensorView<T> getView();
__target_intrinsic(cuda, "$0.dims()")
__target_intrinsic(cpp, "$0.dims()")
[__readNone]
+ [CudaHost]
uint dims();
__target_intrinsic(cuda, "$0.size($1)")
__target_intrinsic(cpp, "$0.size($1)")
[__readNone]
+ [CudaHost]
uint size(uint i);
__target_intrinsic(cuda, "$0.stride($1)")
__target_intrinsic(cpp, "$0.stride($1)")
[__readNone]
+ [CudaHost]
uint stride(uint i);
__target_intrinsic(cuda, "$0.data_ptr<$G0>()")
__target_intrinsic(cpp, "$0.data_ptr<$G0>()")
[__readNone]
+ [CudaHost]
Ptr<T> data_ptr();
__intrinsic_op($(kIROp_AllocateTorchTensor))
+ [CudaHost]
static TorchTensor<T> alloc(uint x);
__intrinsic_op($(kIROp_AllocateTorchTensor))
+ [CudaHost]
static TorchTensor<T> alloc(uint x, uint y);
__intrinsic_op($(kIROp_AllocateTorchTensor))
+ [CudaHost]
static TorchTensor<T> alloc(uint x, uint y, uint z);
__intrinsic_op($(kIROp_AllocateTorchTensor))
+ [CudaHost]
static TorchTensor<T> alloc(uint x, uint y, uint z, uint w);
__intrinsic_op($(kIROp_AllocateTorchTensor))
+ [CudaHost]
static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4);
__intrinsic_op($(kIROp_AllocateTorchTensor))
+ [CudaHost]
static TorchTensor<T> emptyLike(TorchTensor<T> other);
__target_intrinsic(cpp, "$0.zero_()")
+ [CudaHost]
void fillZero();
__target_intrinsic(cpp, "$0.fill_($1)")
+ [CudaHost]
void fillValue(T val);
+ [CudaHost]
static TorchTensor<T> zerosLike(TorchTensor<T> other)
{
var result = emptyLike(other);
@@ -854,8 +868,10 @@ T detach(T x);
#define SLANG_SQR(x) ((x)*(x))
+// Sign of `x` expressed as a ReturnType value: 1 where x > 0, 0 where x == 0, and -1 in every
+// other case (i.e. whenever neither comparison holds). `x` appears twice in the expansion, and
+// `T` / `ReturnType` must be in scope wherever the macro is expanded.
+#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0))))
+
// Absolute value
-UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut)))
+UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut))
// Saturate
UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut))
// frac
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index a11474c8f..723dc4d64 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -614,6 +614,9 @@ DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can
DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.")
DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.")
+
+DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.")
+
//
// 5xxxx - Target code generation.
//
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 2b5323e3f..75d7b3cf8 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -2215,7 +2215,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
case kIROp_Select:
{
-
auto prec = getInfo(EmitOp::Conditional);
needClose = maybeEmitParens(outerPrec, prec);
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index df6ab4901..49edec92b 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1615,6 +1615,17 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
m_writer->emit(")");
return true;
}
+ case kIROp_Select:
+ {
+ m_writer->emit("_slang_select(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
}
}
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d81a33719..6c9280fca 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -195,8 +195,54 @@ public:
}
}
+    /// Returns false when `type` (with attributes unwrapped) is a TorchTensor type, or is an
+    /// array or struct type that contains one through its element / field types.
+    /// Returns true for every other type.
+    bool checkType(IRInst* type)
+    {
+        type = unwrapAttributedType(type);
+
+        if (as<IRTorchTensorType>(type))
+            return false;
+
+        // Arrays are only as valid as their element type.
+        if (auto arrayType = as<IRArrayTypeBase>(type))
+            return checkType(arrayType->getElementType());
+
+        auto structType = as<IRStructType>(type);
+        if (!structType)
+            return true;
+
+        // A struct is rejected as soon as any one of its fields is rejected.
+        for (auto member : structType->getFields())
+        {
+            if (!checkType(member->getFieldType()))
+                return false;
+        }
+        return true;
+    }
+    /// Diagnoses use of TorchTensor-typed values inside `funcInst` unless the function (or its
+    /// outer generic) carries a CudaHost or TorchEntryPoint decoration. At most one diagnostic
+    /// is emitted per function: scanning stops at the first offending instruction.
+    void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst)
+    {
+        auto outerFuncInst = maybeFindOuterGeneric(funcInst);
+
+        // Functions decorated as CUDA-host or as torch entry points are exempt from the check.
+        const bool isExempt = outerFuncInst->findDecoration<IRCudaHostDecoration>()
+            || outerFuncInst->findDecoration<IRTorchEntryPointDecoration>();
+        if (isExempt)
+            return;
+
+        // This is a kernel function, we don't allow using TorchTensor type here.
+        for (auto block : funcInst->getBlocks())
+        {
+            for (auto inst : block->getChildren())
+            {
+                if (checkType(inst->getDataType()))
+                    continue;
+
+                // Prefer the instruction's own location; fall back to the function's.
+                auto loc = inst->sourceLoc.isValid() ? inst->sourceLoc : funcInst->sourceLoc;
+                sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc);
+                return;
+            }
+        }
+    }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
+ checkForInvalidHostTypeUsage(funcInst);
+
if (!_isFuncMarkedForAutoDiff(funcInst))
return;
if (!funcInst->getFirstBlock())
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index ba34a725d..55c2b18c0 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -411,10 +411,6 @@ bool isPtrLikeOrHandleType(IRInst* type)
return true;
if (as<IRPseudoPtrType>(type))
return true;
- if (as<IRMeshOutputType>(type))
- return true;
- if (as<IRHLSLOutputPatchType>(type))
- return true;
switch (type->getOp())
{
case kIROp_ComPtrType:
@@ -824,30 +820,30 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst*
if (!isPtrLikeOrHandleType(type))
return false;
+ if (root)
+ {
+ if (as<IRParameterGroupType>(root->getDataType()))
+ {
+ return false;
+ }
+ }
+
switch (root->getOp())
{
case kIROp_GlobalVar:
- return true;
case kIROp_GlobalParam:
case kIROp_GlobalConstant:
case kIROp_Var:
case kIROp_Param:
break;
+ case kIROp_Call:
+ return true;
default:
- // The inst is defined by an unknown inst.
return true;
}
- if (root)
- {
- if (as<IRParameterGroupType>(root->getDataType()))
- {
- return false;
- }
- auto addrInstParent = getParentFunc(root);
- return (addrInstParent != parentFunc);
- }
- return false;
+ auto addrInstParent = getParentFunc(root);
+ return (addrInstParent != parentFunc);
}
struct GenericChildrenMigrationContextImpl
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index a1c8f9b03..f2598786f 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -125,6 +125,14 @@ namespace Slang {
+/// Returns the build tag string. Normally this is SLANG_TAG_VERSION; when that macro is the
+/// literal "unknown", the result of SharedLibraryUtils::getSharedLibraryTimestamp (called with
+/// the address of spCreateSession) is returned instead. That fallback string is held in a
+/// function-local static, so it is built once and the returned pointer stays valid afterwards.
const char* getBuildTagString()
{
+ if (UnownedStringSlice(SLANG_TAG_VERSION) == "unknown")
+ {
+ // If the tag is unknown, then we will try to get the timestamp of the shared library
+ // and use that as the version string, so that we can at least return something
+ // that uniquely identifies the build.
+ static String timeStampString = String(SharedLibraryUtils::getSharedLibraryTimestamp((void*)spCreateSession));
+ return timeStampString.getBuffer();
+ }
return SLANG_TAG_VERSION;
}