diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 18 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 46 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 8 |
7 files changed, 97 insertions, 18 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 0725103da..f2fd8e3b0 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -246,52 +246,66 @@ __intrinsic_type($(kIROp_TorchTensorType)) struct TorchTensor { __intrinsic_op($(kIROp_TorchTensorGetView)) + [CudaHost] TensorView<T> getView(); __target_intrinsic(cuda, "$0.dims()") __target_intrinsic(cpp, "$0.dims()") [__readNone] + [CudaHost] uint dims(); __target_intrinsic(cuda, "$0.size($1)") __target_intrinsic(cpp, "$0.size($1)") [__readNone] + [CudaHost] uint size(uint i); __target_intrinsic(cuda, "$0.stride($1)") __target_intrinsic(cpp, "$0.stride($1)") [__readNone] + [CudaHost] uint stride(uint i); __target_intrinsic(cuda, "$0.data_ptr<$G0>()") __target_intrinsic(cpp, "$0.data_ptr<$G0>()") [__readNone] + [CudaHost] Ptr<T> data_ptr(); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x, uint y); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x, uint y, uint z); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x, uint y, uint z, uint w); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> emptyLike(TorchTensor<T> other); __target_intrinsic(cpp, "$0.zero_()") + [CudaHost] void fillZero(); __target_intrinsic(cpp, "$0.fill_($1)") + [CudaHost] void fillValue(T val); + [CudaHost] static TorchTensor<T> zerosLike(TorchTensor<T> other) { var result = emptyLike(other); @@ -854,8 +868,10 @@ T detach(T x); #define SLANG_SQR(x) ((x)*(x)) +#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0)))) + // Absolute value -UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut))) +UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut)) // Saturate UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index a11474c8f..723dc4d64 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -614,6 +614,9 @@ DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.") DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.") + +DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.") + // // 5xxxx - Target code generation. // diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 2b5323e3f..75d7b3cf8 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2215,7 +2215,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO case kIROp_Select: { - auto prec = getInfo(EmitOp::Conditional); needClose = maybeEmitParens(outerPrec, prec); diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index df6ab4901..49edec92b 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1615,6 +1615,17 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut m_writer->emit(")"); return true; } + case kIROp_Select: + { + m_writer->emit("_slang_select("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } } } diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d81a33719..6c9280fca 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -195,8 +195,54 @@ public: } } + bool checkType(IRInst* type) + { + type = unwrapAttributedType(type); + if (as<IRTorchTensorType>(type)) + return false; + else if (auto arrayType = as<IRArrayTypeBase>(type)) + return checkType(arrayType->getElementType()); + else if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + if (!checkType(field->getFieldType())) + return false; + } + } + return true; + } + void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst) + { + auto outerFuncInst = maybeFindOuterGeneric(funcInst); + + if (outerFuncInst->findDecoration<IRCudaHostDecoration>()) + return; + if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>()) + return; + + // This is a kernel function, we don't allow using TorchTensor type here. + for (auto b : funcInst->getBlocks()) + { + for (auto inst : b->getChildren()) + { + if (!checkType(inst->getDataType())) + { + auto loc = inst->sourceLoc; + if (!loc.isValid()) + loc = funcInst->sourceLoc; + sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc); + return; + } + + } + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { + checkForInvalidHostTypeUsage(funcInst); + if (!_isFuncMarkedForAutoDiff(funcInst)) return; if (!funcInst->getFirstBlock()) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index ba34a725d..55c2b18c0 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -411,10 +411,6 @@ bool isPtrLikeOrHandleType(IRInst* type) return true; if (as<IRPseudoPtrType>(type)) return true; - if (as<IRMeshOutputType>(type)) - return true; - if (as<IRHLSLOutputPatchType>(type)) - return true; switch (type->getOp()) { case kIROp_ComPtrType: @@ -824,30 +820,30 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* if (!isPtrLikeOrHandleType(type)) return false; + if (root) + { + if (as<IRParameterGroupType>(root->getDataType())) + { + return false; + } + } + switch (root->getOp()) { case kIROp_GlobalVar: - return true; case kIROp_GlobalParam: case kIROp_GlobalConstant: case kIROp_Var: case kIROp_Param: break; + case kIROp_Call: + return true; default: - // The inst is defined by an unknown inst. return true; } - if (root) - { - if (as<IRParameterGroupType>(root->getDataType())) - { - return false; - } - auto addrInstParent = getParentFunc(root); - return (addrInstParent != parentFunc); - } - return false; + auto addrInstParent = getParentFunc(root); + return (addrInstParent != parentFunc); } struct GenericChildrenMigrationContextImpl diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index a1c8f9b03..f2598786f 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -125,6 +125,14 @@ namespace Slang { const char* getBuildTagString() { + if (UnownedStringSlice(SLANG_TAG_VERSION) == "unknown") + { + // If the tag is unknown, then we will try to get the timestamp of the shared library + // and use that as the version string, so that we can at least return something + // that uniquely identifies the build. + static String timeStampString = String(SharedLibraryUtils::getSharedLibraryTimestamp((void*)spCreateSession)); + return timeStampString.getBuffer(); + } return SLANG_TAG_VERSION; } |
