summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2023-08-01 15:39:28 -0400
committerGitHub <noreply@github.com>2023-08-01 15:39:28 -0400
commit1653731718e75c297730dfb878e9f23895d1051d (patch)
tree76f05056594f9910e5baf464b05a41e48398fe18 /source
parentedcc50cdcaf3743d4140b439375d0d40e3a941f7 (diff)
Fix literals needing cast (#3039)
* Cast integer literals. * Fix expected output. * For CUDA, search global instructions to see what types are used. Improve lookup for fp16 header in CUDA. * Fix issue with f16tof32 * Small improvement around finding used base types.
Diffstat (limited to 'source')
-rw-r--r--source/compiler-core/slang-nvrtc-compiler.cpp26
-rw-r--r--source/slang/slang-emit-c-like.cpp30
-rw-r--r--source/slang/slang-emit-cuda.cpp47
-rw-r--r--source/slang/slang-emit-cuda.h1
-rw-r--r--source/slang/slang-ir-constexpr.cpp2
5 files changed, 79 insertions, 27 deletions
diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp
index c756955ec..daa392120 100644
--- a/source/compiler-core/slang-nvrtc-compiler.cpp
+++ b/source/compiler-core/slang-nvrtc-compiler.cpp
@@ -628,11 +628,35 @@ SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath)
String libPath = SharedLibraryUtils::getSharedLibraryFileName((void*)m_nvrtcCreateProgram);
if (libPath.getLength())
{
- const String parentPath = Path::getParentDirectory(libPath);
+ String parentPath = Path::getParentDirectory(libPath);
+
if (SLANG_SUCCEEDED(_findFileInIncludePath(parentPath, g_fp16HeaderName, outPath)))
{
return SLANG_OK;
}
+
+ // See if the shared library is in the SDK; if so, we know how to find the includes.
+ // TODO(JS):
+ // This directory structure is correct for Windows, but may be different elsewhere.
+ {
+ List<UnownedStringSlice> pathSlices;
+ Path::split(parentPath.getUnownedSlice(), pathSlices);
+
+ // The path element at index (count - 2) holds the version number.
+ const auto pathSplitCount = pathSlices.getCount();
+ if (pathSplitCount >= 3 &&
+ pathSlices[pathSplitCount - 1] == toSlice("bin") &&
+ pathSlices[pathSplitCount - 3] == toSlice("CUDA"))
+ {
+ // We want to make sure that one of these paths is CUDA...
+ const auto sdkPath = Path::getParentDirectory(parentPath);
+
+ if (SLANG_SUCCEEDED(_findFileInIncludePath(sdkPath, g_fp16HeaderName, outPath)))
+ {
+ return SLANG_OK;
+ }
+ }
+ }
}
}
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index b05174b0a..adfc98dfd 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1009,13 +1009,6 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
{
default:
- case BaseType::Int:
- {
- m_writer->emit("int(");
- m_writer->emit(int32_t(litInst->value.intVal));
- m_writer->emit(")");
- return;
- }
case BaseType::Int8:
{
m_writer->emit("int8_t(");
@@ -1023,6 +1016,14 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
m_writer->emit(")");
return;
}
+ case BaseType::UInt8:
+ {
+ m_writer->emit("uint8_t(");
+ m_writer->emit(UInt(uint8_t(litInst->value.intVal)));
+ m_writer->emit("U");
+ m_writer->emit(")");
+ break;
+ }
case BaseType::Int16:
{
m_writer->emit("int16_t(");
@@ -1030,18 +1031,21 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
m_writer->emit(")");
return;
}
- case BaseType::UInt8:
- {
- m_writer->emit(UInt(uint8_t(litInst->value.intVal)));
- m_writer->emit("U");
- break;
- }
case BaseType::UInt16:
{
+ m_writer->emit("uint16_t(");
m_writer->emit(UInt(uint16_t(litInst->value.intVal)));
m_writer->emit("U");
+ m_writer->emit(")");
break;
}
+ case BaseType::Int:
+ {
+ m_writer->emit("int(");
+ m_writer->emit(int32_t(litInst->value.intVal));
+ m_writer->emit(")");
+ return;
+ }
case BaseType::UInt:
{
m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index a6501b5be..345aa3168 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -10,7 +10,29 @@
namespace Slang {
+static CUDAExtensionTracker::BaseTypeFlags _findBaseTypesUsed(IRModule* module)
+{
+ typedef CUDAExtensionTracker::BaseTypeFlags Flags;
+
+ // All basic types are hoistable so must be in global scope.
+ Flags baseTypesUsed = 0;
+
+ auto moduleInst = module->getModuleInst();
+
+ // Search all the insts in global scope, for BasicTypes
+ for (auto inst : moduleInst->getChildren())
+ {
+ if (auto basicType = as<IRBasicType>(inst))
+ {
+ // Get the base type, and set the bit
+ const auto baseTypeEnum = basicType->getBaseType();
+ baseTypesUsed |= Flags(1) << int(baseTypeEnum);
+ }
+ }
+
+ return baseTypesUsed;
+}
void CUDAExtensionTracker::finalize()
{
@@ -48,12 +70,8 @@ UnownedStringSlice CUDASourceEmitter::getBuiltinTypeName(IROp op)
case kIROp_IntPtrType: return UnownedStringSlice("int");
case kIROp_UIntPtrType: return UnownedStringSlice("uint");
#endif
- case kIROp_HalfType:
- {
- m_extensionTracker->requireBaseType(BaseType::Half);
- return UnownedStringSlice("__half");
- }
-
+ case kIROp_HalfType: return UnownedStringSlice("__half");
+
case kIROp_FloatType: return UnownedStringSlice("float");
case kIROp_DoubleType: return UnownedStringSlice("double");
default: return UnownedStringSlice();
@@ -77,11 +95,7 @@ UnownedStringSlice CUDASourceEmitter::getVectorPrefix(IROp op)
case kIROp_UIntType: return UnownedStringSlice("uint");
case kIROp_UInt64Type: return UnownedStringSlice("ulonglong");
- case kIROp_HalfType:
- {
- m_extensionTracker->requireBaseType(BaseType::Half);
- return UnownedStringSlice("__half");
- }
+ case kIROp_HalfType: return UnownedStringSlice("__half");
case kIROp_FloatType: return UnownedStringSlice("float");
case kIROp_DoubleType: return UnownedStringSlice("double");
@@ -424,8 +438,14 @@ void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operand
void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec)
{
- if (targetIntrinsic->getDefinition().startsWith("__half"))
+ // This works around the problem that some intrinsics which require the "half" type to be enabled don't use the half/float16_t type themselves.
+ // For example, `f16tof32` can operate on float16_t *and* uint. If the input is uint, then although we are
+ // using the half feature (as far as CUDA is concerned), the half/float16_t type is not visible/directly used.
+ if (targetIntrinsic->getDefinition().startsWith(toSlice("__half")))
+ {
m_extensionTracker->requireBaseType(BaseType::Half);
+ }
+
Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec);
}
@@ -795,6 +815,9 @@ bool CUDASourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* v
void CUDASourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink)
{
+ // Set up with all of the base types used in the module
+ m_extensionTracker->requireBaseTypes(_findBaseTypesUsed(module));
+
CLikeSourceEmitter::emitModuleImpl(module, sink);
// Emit all witness table definitions.
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index 47f6eea06..b067e0010 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -17,6 +17,7 @@ public:
void requireBaseType(BaseType baseType) { m_baseTypeFlags |= _getFlag(baseType); }
bool isBaseTypeRequired(BaseType baseType) { return (m_baseTypeFlags & _getFlag(baseType)) != 0; }
+ void requireBaseTypes(BaseTypeFlags flags) { m_baseTypeFlags |= flags; }
/// Ensure that the generated code is compiled for at least CUDA SM `version`
void requireSMVersion(const SemanticVersion& smVersion) { m_smVersion = (smVersion > m_smVersion) ? smVersion : m_smVersion; }
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 1d6fd163c..dbfec9ae7 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -147,7 +147,7 @@ bool opCanBeConstExpr(IRInst* value)
// TODO: realistically need to special-case `call`
// operations here, so that we check whether the
// callee function is fixed/known, and if it is
- // whether it has been decoared as constant-foldable
+ // whether it has been declared as constant-foldable
return opCanBeConstExpr(value->getOp());
}