diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2023-08-01 15:39:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-01 15:39:28 -0400 |
| commit | 1653731718e75c297730dfb878e9f23895d1051d (patch) | |
| tree | 76f05056594f9910e5baf464b05a41e48398fe18 /source | |
| parent | edcc50cdcaf3743d4140b439375d0d40e3a941f7 (diff) | |
Fix literals needing cast (#3039)
* Cast integer literals.
* Fix expected output.
* For CUDA, search global instructions to see what types are used.
Improve lookup for fp16 header in CUDA.
* Fix issue with f16tof32
* Small improvement around finding used base types.
Diffstat (limited to 'source')
| -rw-r--r-- | source/compiler-core/slang-nvrtc-compiler.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-constexpr.cpp | 2 |
5 files changed, 79 insertions, 27 deletions
```diff
diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp
index c756955ec..daa392120 100644
--- a/source/compiler-core/slang-nvrtc-compiler.cpp
+++ b/source/compiler-core/slang-nvrtc-compiler.cpp
@@ -628,11 +628,35 @@ SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath)
         String libPath = SharedLibraryUtils::getSharedLibraryFileName((void*)m_nvrtcCreateProgram);
         if (libPath.getLength())
         {
-            const String parentPath = Path::getParentDirectory(libPath);
+            String parentPath = Path::getParentDirectory(libPath);
+
             if (SLANG_SUCCEEDED(_findFileInIncludePath(parentPath, g_fp16HeaderName, outPath)))
             {
                 return SLANG_OK;
             }
+
+            // See if the shared library is in the SDK, as if so we know how to find the includes
+            // TODO(JS):
+            // This directory structure is correct for windows perhaps could be different elsewhere.
+            {
+                List<UnownedStringSlice> pathSlices;
+                Path::split(parentPath.getUnownedSlice(), pathSlices);
+
+                // This -2 split holds the version number.
+                const auto pathSplitCount = pathSlices.getCount();
+                if (pathSplitCount >= 3 &&
+                    pathSlices[pathSplitCount - 1] == toSlice("bin") &&
+                    pathSlices[pathSplitCount - 3] == toSlice("CUDA"))
+                {
+                    // We want to make sure that one of these paths is CUDA...
+                    const auto sdkPath = Path::getParentDirectory(parentPath);
+
+                    if (SLANG_SUCCEEDED(_findFileInIncludePath(sdkPath, g_fp16HeaderName, outPath)))
+                    {
+                        return SLANG_OK;
+                    }
+                }
+            }
         }
     }

diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index b05174b0a..adfc98dfd 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1009,13 +1009,6 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
                 {
                     default:

-                    case BaseType::Int:
-                    {
-                        m_writer->emit("int(");
-                        m_writer->emit(int32_t(litInst->value.intVal));
-                        m_writer->emit(")");
-                        return;
-                    }
                     case BaseType::Int8:
                     {
                         m_writer->emit("int8_t(");
@@ -1023,6 +1016,14 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
                         m_writer->emit(")");
                         return;
                     }
+                    case BaseType::UInt8:
+                    {
+                        m_writer->emit("uint8_t(");
+                        m_writer->emit(UInt(uint8_t(litInst->value.intVal)));
+                        m_writer->emit("U");
+                        m_writer->emit(")");
+                        break;
+                    }
                     case BaseType::Int16:
                     {
                         m_writer->emit("int16_t(");
@@ -1030,18 +1031,21 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
                         m_writer->emit(")");
                         return;
                     }
-                    case BaseType::UInt8:
-                    {
-                        m_writer->emit(UInt(uint8_t(litInst->value.intVal)));
-                        m_writer->emit("U");
-                        break;
-                    }
                     case BaseType::UInt16:
                     {
+                        m_writer->emit("uint16_t(");
                         m_writer->emit(UInt(uint16_t(litInst->value.intVal)));
                         m_writer->emit("U");
+                        m_writer->emit(")");
                         break;
                     }
+                    case BaseType::Int:
+                    {
+                        m_writer->emit("int(");
+                        m_writer->emit(int32_t(litInst->value.intVal));
+                        m_writer->emit(")");
+                        return;
+                    }
                     case BaseType::UInt:
                     {
                         m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index a6501b5be..345aa3168 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -10,7 +10,29 @@

 namespace Slang {

+static CUDAExtensionTracker::BaseTypeFlags _findBaseTypesUsed(IRModule* module)
+{
+    typedef CUDAExtensionTracker::BaseTypeFlags Flags;
+
+    // All basic types are hoistable so must be in global scope.
+    Flags baseTypesUsed = 0;
+
+    auto moduleInst = module->getModuleInst();
+
+    // Search all the insts in global scope, for BasicTypes
+    for (auto inst : moduleInst->getChildren())
+    {
+        if (auto basicType = as<IRBasicType>(inst))
+        {
+            // Get the base type, and set the bit
+            const auto baseTypeEnum = basicType->getBaseType();
+            baseTypesUsed |= Flags(1) << int(baseTypeEnum);
+        }
+    }
+
+    return baseTypesUsed;
+}


 void CUDAExtensionTracker::finalize()
 {
@@ -48,12 +70,8 @@ UnownedStringSlice CUDASourceEmitter::getBuiltinTypeName(IROp op)
         case kIROp_IntPtrType: return UnownedStringSlice("int");
         case kIROp_UIntPtrType: return UnownedStringSlice("uint");
 #endif
-        case kIROp_HalfType:
-        {
-            m_extensionTracker->requireBaseType(BaseType::Half);
-            return UnownedStringSlice("__half");
-        }
-
+        case kIROp_HalfType: return UnownedStringSlice("__half");
+
         case kIROp_FloatType: return UnownedStringSlice("float");
         case kIROp_DoubleType: return UnownedStringSlice("double");
         default: return UnownedStringSlice();
@@ -77,11 +95,7 @@ UnownedStringSlice CUDASourceEmitter::getVectorPrefix(IROp op)
         case kIROp_UIntType: return UnownedStringSlice("uint");
         case kIROp_UInt64Type: return UnownedStringSlice("ulonglong");

-        case kIROp_HalfType:
-        {
-            m_extensionTracker->requireBaseType(BaseType::Half);
-            return UnownedStringSlice("__half");
-        }
+        case kIROp_HalfType: return UnownedStringSlice("__half");
         case kIROp_FloatType: return UnownedStringSlice("float");
         case kIROp_DoubleType: return UnownedStringSlice("double");

@@ -424,8 +438,14 @@ void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operand

 void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec)
 {
-    if (targetIntrinsic->getDefinition().startsWith("__half"))
+    // This works around the problem, where some intrinsics that require the "half" type enabled don't use the half/float16_t type.
+    // For example `f16tof32` can operate on float16_t *and* uint. If the input is uint, although we are
+    // using the half feature (as far as CUDA is concerned), the half/float16_t type is not visible/directly used.
+    if (targetIntrinsic->getDefinition().startsWith(toSlice("__half")))
+    {
         m_extensionTracker->requireBaseType(BaseType::Half);
+    }
+
     Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec);
 }

@@ -795,6 +815,9 @@ bool CUDASourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* v

 void CUDASourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink)
 {
+    // Set up with all of the base types used in the module
+    m_extensionTracker->requireBaseTypes(_findBaseTypesUsed(module));
+
     CLikeSourceEmitter::emitModuleImpl(module, sink);

     // Emit all witness table definitions.
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index 47f6eea06..b067e0010 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -17,6 +17,7 @@ public:

     void requireBaseType(BaseType baseType) { m_baseTypeFlags |= _getFlag(baseType); }
     bool isBaseTypeRequired(BaseType baseType) { return (m_baseTypeFlags & _getFlag(baseType)) != 0; }
+    void requireBaseTypes(BaseTypeFlags flags) { m_baseTypeFlags |= flags; }

     /// Ensure that the generated code is compiled for at least CUDA SM `version`
     void requireSMVersion(const SemanticVersion& smVersion) { m_smVersion = (smVersion > m_smVersion) ? smVersion : m_smVersion; }
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 1d6fd163c..dbfec9ae7 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -147,7 +147,7 @@ bool opCanBeConstExpr(IRInst* value)
     // TODO: realistically need to special-case `call`
     // operations here, so that we check whether the
     // callee function is fixed/known, and if it is
-    // whether it has been decoared as constant-foldable
+    // whether it has been declared as constant-foldable
     return opCanBeConstExpr(value->getOp());
 }

```
