diff options
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 30 |
1 files changed, 20 insertions, 10 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index faf748098..0319d1f7d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -107,6 +107,7 @@ Session::Session() } m_defaultDownstreamCompilers[Index(SourceLanguage::C)] = PassThroughMode::GenericCCpp; m_defaultDownstreamCompilers[Index(SourceLanguage::CPP)] = PassThroughMode::GenericCCpp; + m_defaultDownstreamCompilers[Index(SourceLanguage::CUDA)] = PassThroughMode::NVRTC; } } @@ -193,7 +194,14 @@ SLANG_NO_THROW void SLANG_MCALL Session::setDownstreamCompilerPath( case PassThroughMode::GenericCCpp: { // If any compiler path set changed, require all to be refreshed - cppCompilerSet.setNull(); + downstreamCompilerSet.setNull(); + break; + } + case PassThroughMode::NVRTC: + { + // TODO(JS): We need a way to set the NVRTC path. + // We want to unload... and try again... + downstreamCompilerSet.setNull(); break; } default: break; @@ -249,6 +257,10 @@ static bool _canCompile(PassThroughMode compiler, SourceLanguage sourceLanguage) { return sourceLanguage == SourceLanguage::C || sourceLanguage == SourceLanguage::CPP; } + case PassThroughMode::NVRTC: + { + return sourceLanguage == SourceLanguage::CUDA; + } default: break; } return false; @@ -259,13 +271,10 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::setDefaultDownstreamCompiler(Sla auto sourceLanguage = SourceLanguage(inSourceLanguage); auto compiler = PassThroughMode(defaultCompiler); - if (sourceLanguage == SourceLanguage::C || sourceLanguage == SourceLanguage::CPP) + if (_canCompile(compiler, sourceLanguage)) { - if (_canCompile(compiler, sourceLanguage)) - { - m_defaultDownstreamCompilers[int(sourceLanguage)] = compiler; - return SLANG_OK; - } + m_defaultDownstreamCompilers[int(sourceLanguage)] = compiler; + return SLANG_OK; } return SLANG_FAIL; @@ -280,19 +289,20 @@ SlangPassThrough SLANG_MCALL Session::getDefaultDownstreamCompiler(SlangSourceLa DownstreamCompiler* Session::getDownstreamCompiler(PassThroughMode compiler) { - DownstreamCompilerSet* compilerSet = requireCPPCompilerSet(); + DownstreamCompilerSet* compilerSet = requireDownstreamCompilerSet(); switch (compiler) { - case PassThroughMode::GenericCCpp: return compilerSet->getDefaultCompiler(); + case PassThroughMode::GenericCCpp: return compilerSet->getDefaultCompiler(DownstreamCompiler::SourceType::CPP); case PassThroughMode::Clang: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::Clang)); case PassThroughMode::VisualStudio: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::VisualStudio)); case PassThroughMode::Gcc: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::GCC)); + case PassThroughMode::NVRTC: return compilerSet->getDefaultCompiler(DownstreamCompiler::SourceType::CUDA); default: break; } return nullptr; } -DownstreamCompiler* Session::getDefaultCPPCompiler(SourceLanguage sourceLanguage) +DownstreamCompiler* Session::getDefaultDownstreamCompiler(SourceLanguage sourceLanguage) { return getDownstreamCompiler(m_defaultDownstreamCompilers[int(sourceLanguage)]); } |
