summaryrefslogtreecommitdiffstats
path: root/source/slang/slang.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang.cpp')
-rw-r--r--source/slang/slang.cpp30
1 files changed, 20 insertions, 10 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index faf748098..0319d1f7d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -107,6 +107,7 @@ Session::Session()
}
m_defaultDownstreamCompilers[Index(SourceLanguage::C)] = PassThroughMode::GenericCCpp;
m_defaultDownstreamCompilers[Index(SourceLanguage::CPP)] = PassThroughMode::GenericCCpp;
+ m_defaultDownstreamCompilers[Index(SourceLanguage::CUDA)] = PassThroughMode::NVRTC;
}
}
@@ -193,7 +194,14 @@ SLANG_NO_THROW void SLANG_MCALL Session::setDownstreamCompilerPath(
case PassThroughMode::GenericCCpp:
{
// If any compiler path set changed, require all to be refreshed
- cppCompilerSet.setNull();
+ downstreamCompilerSet.setNull();
+ break;
+ }
+ case PassThroughMode::NVRTC:
+ {
+ // TODO(JS): We need a way to set the NVRTC path.
+ // We want to unload... and try again...
+ downstreamCompilerSet.setNull();
break;
}
default: break;
@@ -249,6 +257,10 @@ static bool _canCompile(PassThroughMode compiler, SourceLanguage sourceLanguage)
{
return sourceLanguage == SourceLanguage::C || sourceLanguage == SourceLanguage::CPP;
}
+ case PassThroughMode::NVRTC:
+ {
+ return sourceLanguage == SourceLanguage::CUDA;
+ }
default: break;
}
return false;
@@ -259,13 +271,10 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::setDefaultDownstreamCompiler(Sla
auto sourceLanguage = SourceLanguage(inSourceLanguage);
auto compiler = PassThroughMode(defaultCompiler);
- if (sourceLanguage == SourceLanguage::C || sourceLanguage == SourceLanguage::CPP)
+ if (_canCompile(compiler, sourceLanguage))
{
- if (_canCompile(compiler, sourceLanguage))
- {
- m_defaultDownstreamCompilers[int(sourceLanguage)] = compiler;
- return SLANG_OK;
- }
+ m_defaultDownstreamCompilers[int(sourceLanguage)] = compiler;
+ return SLANG_OK;
}
return SLANG_FAIL;
@@ -280,19 +289,20 @@ SlangPassThrough SLANG_MCALL Session::getDefaultDownstreamCompiler(SlangSourceLa
DownstreamCompiler* Session::getDownstreamCompiler(PassThroughMode compiler)
{
- DownstreamCompilerSet* compilerSet = requireCPPCompilerSet();
+ DownstreamCompilerSet* compilerSet = requireDownstreamCompilerSet();
switch (compiler)
{
- case PassThroughMode::GenericCCpp: return compilerSet->getDefaultCompiler();
+ case PassThroughMode::GenericCCpp: return compilerSet->getDefaultCompiler(DownstreamCompiler::SourceType::CPP);
case PassThroughMode::Clang: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::Clang));
case PassThroughMode::VisualStudio: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::VisualStudio));
case PassThroughMode::Gcc: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::GCC));
+ case PassThroughMode::NVRTC: return compilerSet->getDefaultCompiler(DownstreamCompiler::SourceType::CUDA);
default: break;
}
return nullptr;
}
-DownstreamCompiler* Session::getDefaultCPPCompiler(SourceLanguage sourceLanguage)
+DownstreamCompiler* Session::getDefaultDownstreamCompiler(SourceLanguage sourceLanguage)
{
return getDownstreamCompiler(m_defaultDownstreamCompilers[int(sourceLanguage)]);
}